
JaxStackTraceBeforeTransformation error with parametrized ODE #513

Open
SnowOwl-Hedwig opened this issue Oct 11, 2024 · 4 comments
Labels
question User queries

Comments

@SnowOwl-Hedwig

Hi,

based on this tutorial I tried to get started with Jax and neural ODEs: https://colab.research.google.com/drive/1ZlK36VgWy1vBjBNXjSUg6Cb-7zeoa3jh

However, I get a JaxStackTraceBeforeTransformation error (detailed error message below). I boiled the code down to a minimal working example (also below) and noticed that the error only occurs when the equation in test_func depends on its argument (_coeff).
Since this seemed similar to an earlier issue (jax-ml/jax#13629), I tried downgrading JAX to version 0.4.23. I also tried setting up a fresh Python environment with only the necessary packages installed. Nothing has helped so far. I'd appreciate your help :)

(Even though it's reported as a JaxStackTraceBeforeTransformation error, @dfm pointed out in jax-ml/jax#24253 that it might actually be a problem with diffrax: "The error reported here is actually a TypeError being raised because of an issue with the return types in a jax.custom_jvp. It's hard to see from this error report exactly which custom_jvp is the culprit, but it seems like it must be something within diffrax or equinox, so I'd recommend opening the issue on the https://github.com/patrick-kidger/diffrax issue tracker.")

Working example:

from diffrax import diffeqsolve, ODETerm, SaveAt, Tsit5
import jax
import jax.numpy as jnp

def f(t, y, _):
  dp_dt = 0.9 * y
  return dp_dt
    
b0 = 2  # init condition
data_ts = jnp.linspace(0, 20, 100)
data_sol = diffeqsolve(ODETerm(f), Tsit5(), t0=0, t1=20, dt0=0.01,
                       y0=(b0), saveat=SaveAt(ts=data_ts))

def fwd_test(coeff):
    num_ts = 100
    def test_func(t, y, _coeff):
        dp_dt = y * _coeff #doesn't work
        # dp_dt = y #works
        return dp_dt
    
    b0 = 2
    model_ts = jnp.linspace(0, 20, num_ts)
    # Note: larger dt0 so that it runs faster; this is about as large as it can go
    model_sol = diffeqsolve(ODETerm(test_func), Tsit5(), t0=0, t1=20, dt0=0.5,
                            y0=(b0), args=coeff, saveat=SaveAt(ts=model_ts))
    model_b = model_sol.ys
    data_b = data_sol.ys
    return jnp.sum((model_b - data_b)**2)

coeff = 1.
grads = jax.grad(fwd_test)(coeff)
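As a solver-free reference, the ODE in this example, dy/dt = coeff * y with y(0) = 2, has the closed-form solution y(t) = 2 * exp(coeff * t), so the same loss can be differentiated with plain jax.grad without going through any diffrax/equinox loops. This is a minimal sketch, assuming the same time grid and data rate (0.9) as above; the names B0, TS, DATA and closed_form_loss are hypothetical. It only approximates what the diffrax version computes, since the reproduction uses numerical Tsit5 solves (dt0=0.01 for the data, dt0=0.5 for the model), but it gives a ballpark value to compare against once the version issue is resolved.

```python
import jax
import jax.numpy as jnp

# Hypothetical names; same setup as the reproduction above:
# y(0) = 2, 100 save points on [0, 20], data generated with rate 0.9.
B0 = 2.0
TS = jnp.linspace(0.0, 20.0, 100)
DATA = B0 * jnp.exp(0.9 * TS)  # exact solution of dy/dt = 0.9 * y


def closed_form_loss(coeff):
    # Exact solution of dy/dt = coeff * y with y(0) = B0. No ODE solver is
    # involved, so no diffrax/equinox custom autodiff rules are exercised.
    model = B0 * jnp.exp(coeff * TS)
    return jnp.sum((model - DATA) ** 2)


grad_closed_form = jax.grad(closed_form_loss)(1.0)
print(grad_closed_form)
```

Because coeff = 1.0 is larger than the data rate 0.9, the model overshoots the data everywhere, so the gradient should be positive.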

Error message:

JaxStackTraceBeforeTransformation         Traceback (most recent call last)
File <frozen runpy>:198, in _run_module_as_main()

File <frozen runpy>:88, in _run_code()

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\ipykernel_launcher.py:18
     16 from ipykernel import kernelapp as app
---> 18 app.launch_new_instance()

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\traitlets\config\application.py:1075, in launch_instance()
   1074 app.initialize(argv)
-> 1075 app.start()

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\ipykernel\kernelapp.py:739, in start()
    738 try:
--> 739     self.io_loop.start()
    740 except KeyboardInterrupt:

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\tornado\platform\asyncio.py:205, in start()
    204 def start(self) -> None:
--> 205     self.asyncio_loop.run_forever()

File ~\AppData\Local\Programs\Python\Python311\Lib\asyncio\base_events.py:607, in run_forever()
    606 while True:
--> 607     self._run_once()
    608     if self._stopping:

File ~\AppData\Local\Programs\Python\Python311\Lib\asyncio\base_events.py:1919, in _run_once()
   1918     else:
-> 1919         handle._run()
   1920 handle = None

File ~\AppData\Local\Programs\Python\Python311\Lib\asyncio\events.py:80, in _run()
     79 try:
---> 80     self._context.run(self._callback, *self._args)
     81 except (SystemExit, KeyboardInterrupt):

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\ipykernel\kernelbase.py:545, in dispatch_queue()
    544 try:
--> 545     await self.process_one()
    546 except Exception:

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\ipykernel\kernelbase.py:534, in process_one()
    533         return
--> 534 await dispatch(*args)

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\ipykernel\kernelbase.py:437, in dispatch_shell()
    436     if inspect.isawaitable(result):
--> 437         await result
    438 except Exception:

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\ipykernel\ipkernel.py:362, in execute_request()
    361 self._associate_new_top_level_threads_with(parent_header)
--> 362 await super().execute_request(stream, ident, parent)

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\ipykernel\kernelbase.py:778, in execute_request()
    777 if inspect.isawaitable(reply_content):
--> 778     reply_content = await reply_content
    780 # Flush output before sending the reply.

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\ipykernel\ipkernel.py:449, in do_execute()
    448 if accepts_params["cell_id"]:
--> 449     res = shell.run_cell(
    450         code,
    451         store_history=store_history,
    452         silent=silent,
    453         cell_id=cell_id,
    454     )
    455 else:

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\ipykernel\zmqshell.py:549, in run_cell()
    548 self._last_traceback = None
--> 549 return super().run_cell(*args, **kwargs)

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\IPython\core\interactiveshell.py:3075, in run_cell()
   3074 try:
-> 3075     result = self._run_cell(
   3076         raw_cell, store_history, silent, shell_futures, cell_id
   3077     )
   3078 finally:

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\IPython\core\interactiveshell.py:3130, in _run_cell()
   3129 try:
-> 3130     result = runner(coro)
   3131 except BaseException as e:

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\IPython\core\async_helpers.py:129, in _pseudo_sync_runner()
    128 try:
--> 129     coro.send(None)
    130 except StopIteration as exc:

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\IPython\core\interactiveshell.py:3334, in run_cell_async()
   3331 interactivity = "none" if silent else self.ast_node_interactivity
-> 3334 has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
   3335        interactivity=interactivity, compiler=compiler, result=result)
   3337 self.last_execution_succeeded = not has_raised

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\IPython\core\interactiveshell.py:3517, in run_ast_nodes()
   3516     asy = compare(code)
-> 3517 if await self.run_code(code, result, async_=asy):
   3518     return True

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\IPython\core\interactiveshell.py:3577, in run_code()
   3576     else:
-> 3577         exec(code_obj, self.user_global_ns, self.user_ns)
   3578 finally:
   3579     # Reset our crash handler in place

Cell In[1], line 32
     31 coeff = 1.
---> 32 grads = jax.grad(fwd_test)(coeff)
     33 # print(grads)

Cell In[1], line 24, in fwd_test()
     23 # Note: larger dt0 so that it runs faster; this is about as large as it can go
---> 24 model_sol = diffeqsolve(ODETerm(test_func), Tsit5(), t0=0, t1=20, dt0=0.5,
     25                     y0=(b0), args=coeff, saveat=SaveAt(ts=model_ts))
     26 model_b = model_sol.ys

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\diffrax\integrate.py:823, in diffeqsolve()
    819 #
    820 # Main loop
    821 #
--> 823 final_state, aux_stats = adjoint.loop(
    824     args=args,
    825     terms=terms,
    826     solver=solver,
    827     stepsize_controller=stepsize_controller,
    828     discrete_terminating_event=discrete_terminating_event,
    829     saveat=saveat,
    830     t0=t0,
    831     t1=t1,
    832     dt0=dt0,
    833     max_steps=max_steps,
    834     init_state=init_state,
    835     throw=throw,
    836     passed_solver_state=passed_solver_state,
    837     passed_controller_state=passed_controller_state,
    838 )
    840 #
    841 # Finish up
    842 #

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\diffrax\adjoint.py:286, in loop()
    285     msg = None
--> 286 final_state = self._loop(
    287     terms=terms,
    288     saveat=saveat,
    289     init_state=init_state,
    290     max_steps=max_steps,
    291     inner_while_loop=inner_while_loop,
    292     outer_while_loop=outer_while_loop,
    293     **kwargs,
    294 )
    295 if msg is not None:

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\diffrax\integrate.py:429, in loop()
    427 del filter_state
--> 429 final_state = outer_while_loop(
    430     cond_fun, body_fun, init_state, max_steps=max_steps, buffers=_outer_buffers
    431 )
    433 def _save_t1(subsaveat, save_state):

File ~\AppData\Local\Programs\Python\Python311\Lib\contextlib.py:81, in inner()
     80 with self._recreate_cm():
---> 81     return func(*args, **kwds)

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\equinox\internal\_loop\loop.py:107, in while_loop()
    106     del kind, base
--> 107     return checkpointed_while_loop(
    108         cond_fun,
    109         body_fun,
    110         init_val,
    111         max_steps=max_steps,
    112         buffers=buffers,
    113         checkpoints=checkpoints,
    114     )
    115 elif kind == "bounded":

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\equinox\internal\_loop\checkpointed.py:247, in checkpointed_while_loop()
    246 cond_fun_ = jtu.tree_map(_stop_gradient, cond_fun_)
--> 247 body_fun_ = filter_closure_convert(body_fun_, init_val_)
    248 vjp_arg = (init_val_, body_fun_)

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\equinox\internal\_loop\common.py:463, in new_body_fun()
    462 buffer_val = _wrap_buffers(val, pred, tag)
--> 463 buffer_val2 = body_fun(buffer_val)
    464 # Needed to work with `disable_jit`, as then we lose the automatic
    465 # ArrayLike->Array cast provided by JAX's while loops.
    466 # The input `val` is already cast to Array below, so this matches that.

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\diffrax\integrate.py:219, in body_fun()
    214 #
    215 # Actually do some differential equation solving! Make numerical steps, adapt
    216 # step sizes, all that jazz.
    217 #
--> 219 (y, y_error, dense_info, solver_state, solver_result) = solver.step(
    220     terms,
    221     state.tprev,
    222     state.tnext,
    223     state.y,
    224     args,
    225     state.solver_state,
    226     state.made_jump,
    227 )
    229 # e.g. if someone has a sqrt(y) in the vector field, and dt0 is so large that
    230 # we get a negative value for y, and then get a NaN vector field. (And then
    231 # everything breaks.) See #143.

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\diffrax\solver\runge_kutta.py:1041, in step()
   1035 # Needs to be an `eqxi.while_loop` as:
   1036 # (a) we may have variable length: e.g. an FSAL explicit RK scheme will have one
   1037 #     more stage on the first step.
   1038 # (b) to work around a limitation of JAX's autodiff being unable to express
   1039 #     "triangular computations" (every stage depends on all previous stages)
   1040 #     without spurious copies.
-> 1041 final_val = eqxi.while_loop(
   1042     cond_stage,
   1043     rk_stage,
   1044     init_val,
   1045     max_steps=num_stages,
   1046     buffers=buffers,
   1047     kind="checkpointed" if self.scan_kind is None else self.scan_kind,
   1048     checkpoints=num_stages,
   1049     base=num_stages,
   1050 )
   1051 _, y1, f1_for_fsal, _, _, fs, ks, result = final_val

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\equinox\internal\_loop\loop.py:107, in while_loop()
    106     del kind, base
--> 107     return checkpointed_while_loop(
    108         cond_fun,
    109         body_fun,
    110         init_val,
    111         max_steps=max_steps,
    112         buffers=buffers,
    113         checkpoints=checkpoints,
    114     )
    115 elif kind == "bounded":

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\equinox\internal\_loop\checkpointed.py:252, in checkpointed_while_loop()
    249 final_val_ = _checkpointed_while_loop(
    250     vjp_arg, cond_fun_, checkpoints, buffers_, max_steps
    251 )
--> 252 _, _, _, final_val = _stop_gradient_on_unperturbed(init_val_, final_val_, body_fun_)
    253 return final_val

JaxStackTraceBeforeTransformation: TypeError: Custom JVP rule must produce primal and tangent outputs with corresponding shapes and dtypes, but got:
  primal int32[] with tangent int32[], expecting tangent ShapedArray(float0[])
  primal bool[] with tangent bool[], expecting tangent ShapedArray(float0[])
  primal bool[] with tangent bool[], expecting tangent ShapedArray(float0[])
  primal int32[] with tangent int32[], expecting tangent ShapedArray(float0[])
  primal int32[] with tangent int32[], expecting tangent ShapedArray(float0[])

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)
Cell In[1], line 32
     28     return jnp.sum((model_b - data_b)**2)
     31 coeff = 1.
---> 32 grads = jax.grad(fwd_test)(coeff)
     33 # print(grads)

    [... skipping hidden 10 frame]

Cell In[1], line 24, in fwd_test(coeff)
     22 model_ts = jnp.linspace(0, 20, num_ts)
     23 # Note: larger dt0 so that it runs faster; this is about as large as it can go
---> 24 model_sol = diffeqsolve(ODETerm(test_func), Tsit5(), t0=0, t1=20, dt0=0.5,
     25                     y0=(b0), args=coeff, saveat=SaveAt(ts=model_ts))
     26 model_b = model_sol.ys
     27 data_b = data_sol.ys

    [... skipping hidden 27 frame]

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\equinox\internal\_loop\checkpointed.py:1272, in _stop_gradient_on_unperturbed_jvp(***failed resolving arguments***)
   1268 del primals, tangents
   1269 perturb_val, perturb_body_fun = jtu.tree_map(
   1270     lambda _, t: t is not None, (init_val, body_fun), (t_init_val, t_body_fun)
   1271 )
-> 1272 perturb_val = _resolve_perturb_val(
   1273     init_val, body_fun, perturb_val, perturb_body_fun
   1274 )
   1275 t_final_val = jtu.tree_map(
   1276     _perturb_to_tang, t_final_val, perturb_val, is_leaf=_is_none
   1277 )
   1278 return final_val, t_final_val

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\equinox\internal\_loop\checkpointed.py:1241, in _resolve_perturb_val(final_val, body_fun, perturb_final_val, perturb_body_fun)
   1238         else:
   1239             perturb_val = jtu.tree_map(operator.or_, perturb_val, new_perturb_val)
-> 1241 perturb_val = jax.eval_shape(_resolve_perturb_val_impl).value
   1242 return perturb_val

    [... skipping hidden 12 frame]

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\equinox\internal\_loop\checkpointed.py:1214, in _resolve_perturb_val.<locals>._resolve_perturb_val_impl()
   1211     return _out
   1213 # Not `jax.jvp`, so as not to error if `body_fun` has any `custom_vjp`s.
-> 1214 jax.linearize(_to_linearize, dynamic)
   1215 if new_perturb_val is sentinel:
   1216     # `_dynamic_out` in `_to_linearize` had no JVP tracers at all, despite
   1217     # `_dynamic` having them. Presumably the user's `_body_fun` has no
   1218     # differentiable dependency whatsoever.
   1219     # This can happen if all the autograd is happening through
   1220     # `perturb_body_fun`.
   1221     return Static(perturb_val)

    [... skipping hidden 5 frame]

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\equinox\internal\_loop\checkpointed.py:1207, in _resolve_perturb_val.<locals>._resolve_perturb_val_impl.<locals>._to_linearize(_dynamic)
   1205 def _to_linearize(_dynamic):
   1206     _body_fun, _val = combine(_dynamic, static)
-> 1207     _out = _body_fun(_val)
   1208     _dynamic_out, _static_out = partition(_out, is_inexact_array)
   1209     _dynamic_out = _record_symbolic_zeros(_dynamic_out)

    [... skipping hidden 10 frame]

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\jax\_src\custom_derivatives.py:351, in _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args)
    344     msg = ("Custom JVP rule must produce primal and tangent outputs with "
    345            "corresponding shapes and dtypes, but got:\n{}")
    346     disagreements = (
    347         f"  primal {av_p.str_short()} with tangent {av_t.str_short()}, expecting tangent {av_et}"
    348         for av_p, av_et, av_t in zip(primal_avals_out, expected_tangent_avals_out, tangent_avals_out)
    349         if av_et != av_t)
--> 351     raise TypeError(msg.format('\n'.join(disagreements)))
    352 yield primals_out + tangents_out, (out_tree, primal_avals)

TypeError: Custom JVP rule must produce primal and tangent outputs with corresponding shapes and dtypes, but got:
  primal int32[] with tangent int32[], expecting tangent ShapedArray(float0[])
  primal bool[] with tangent bool[], expecting tangent ShapedArray(float0[])
  primal bool[] with tangent bool[], expecting tangent ShapedArray(float0[])
  primal int32[] with tangent int32[], expecting tangent ShapedArray(float0[])
  primal int32[] with tangent int32[], expecting tangent ShapedArray(float0[])

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.34
jaxlib: 0.4.34
numpy: 1.26.4
python: 3.11.1 (tags/v3.11.1:a7a450f, Dec 6 2022, 19:58:39) [MSC v.1934 64 bit (AMD64)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Windows', release='10', version='10.0.19044', machine='AMD64')

jupyterlab: 4.2.2
diffrax: 0.4.1

@patrick-kidger
Owner

This is a known issue that arose in JAX 0.4.34: the tangent type of integer values in custom autodiff rules was changed from matching the primal dtype to being float0.

I've updated Equinox to be compatible in patrick-kidger/equinox#871. I'll do a new release soon. In the meantime you can either install Equinox directly from HEAD, or you can downgrade to JAX 0.4.33.

I hope that helps! :)

(I can see that you said you already tried downgrading. I have just double-checked and Equinox v0.11.7 + JAX 0.4.33 works for me, so I think something else has probably gone wrong for you there. :) )
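Since Equinox v0.11.7 + JAX 0.4.33 works on the maintainer's side, one thing worth ruling out (an assumption, not something confirmed in this thread) is that the Jupyter kernel is not running the environment or versions you expect. A minimal sketch using only the standard library; installed_versions is a hypothetical helper that reports each package's installed version as seen by the running interpreter.

```python
import sys
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("jax", "jaxlib", "equinox", "diffrax", "optimistix")):
    """Return {package: installed version or None} for this interpreter."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None  # not installed in this environment
    return found


if __name__ == "__main__":
    # sys.executable shows which Python (and hence which environment) runs this.
    print("interpreter:", sys.executable)
    for name, ver in installed_versions().items():
        print(f"{name:>10}: {ver or 'not installed'}")
```

Run it inside the notebook itself. importlib.metadata reads the installed package metadata, so if packages were downgraded while the kernel was running, restart the kernel so that the already-imported modules match what this reports.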

@patrick-kidger patrick-kidger added the question User queries label Oct 11, 2024
@SnowOwl-Hedwig
Author

SnowOwl-Hedwig commented Oct 11, 2024

I don't know what I'm doing wrong here. I just tried Equinox 0.11.7 with JAX 0.4.33 and still get the same issue. Maybe the new release will help... Fortunately it's not urgent :)

@FFroehlich

Just want to note that I get a similar problem:

../sdist/amici/jax.py:123: in _solve
    sol = diffrax.diffeqsolve(
../../venv/lib/python3.13/site-packages/equinox/internal/_loop/checkpointed.py:1272: in _stop_gradient_on_unperturbed_jvp
    perturb_val = _resolve_perturb_val(
../../venv/lib/python3.13/site-packages/equinox/internal/_loop/checkpointed.py:1241: in _resolve_perturb_val
    perturb_val = jax.eval_shape(_resolve_perturb_val_impl).value
../../venv/lib/python3.13/site-packages/equinox/internal/_loop/checkpointed.py:1214: in _resolve_perturb_val_impl
    jax.linearize(_to_linearize, dynamic)
../../venv/lib/python3.13/site-packages/equinox/internal/_loop/checkpointed.py:1207: in _to_linearize
    _out = _body_fun(_val)
../../venv/lib/python3.13/site-packages/equinox/internal/_loop/checkpointed.py:1272: in _stop_gradient_on_unperturbed_jvp
    perturb_val = _resolve_perturb_val(
../../venv/lib/python3.13/site-packages/equinox/internal/_loop/checkpointed.py:1241: in _resolve_perturb_val
    perturb_val = jax.eval_shape(_resolve_perturb_val_impl).value
../../venv/lib/python3.13/site-packages/equinox/internal/_loop/checkpointed.py:1214: in _resolve_perturb_val_impl
    jax.linearize(_to_linearize, dynamic)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

_dynamic = (_ClosureConvert(
  jaxpr=None,
  consts=[
    None,
    None,
    None,
    f64[],
    None,
    None,
    None,
    ...m(
    _value=None,
    _enumeration=<class 'optimistix._solution.RESULTS'>
  ),
  step=None
), unused, (None,), ...)))

    def _to_linearize(_dynamic):
        _body_fun, _val = combine(_dynamic, static)
>       _out = _body_fun(_val)
E       TypeError: Custom JVP rule must produce primal and tangent outputs with corresponding shapes and dtypes, but got:
E         primal int64[] with tangent int64[], expecting tangent ShapedArray(float0[])
E         primal int32[] with tangent int32[], expecting tangent ShapedArray(float0[])
E         primal int32[] with tangent int32[], expecting tangent ShapedArray(float0[])
E         primal int64[] with tangent int64[], expecting tangent ShapedArray(float0[])
E       --------------------
E       For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

This is with jax==0.4.34, jaxlib==0.4.34, diffrax==0.6.0 and equinox==0.11.8 on Python 3.13.

@FFroehlich

The issue above is fixed with optimistix 0.0.9 :)
