This error was raised in discussion #14561; I'm transferring it here as an issue.
The self-contained example is below:
```python
import numpy as np
import jax.numpy as jnp
from jax import vmap, jit, grad
from jax.lax import switch
from jax.experimental.ode import odeint

def A0(t):
    return 2.

def A1(a, t):
    return a**2

y0 = np.random.rand(2)
T = np.pi * 1.232

def test_func(a):
    eval_list = [A0, lambda t: A1(a, t)]

    def single_eval(idx, t):
        return switch(idx, eval_list, t)

    multiple_eval = vmap(single_eval, in_axes=(0, None))
    idx_list = jnp.array([0, 1])

    rhs = lambda y, t: multiple_eval(idx_list, t) * y
    # using this version of the RHS works, so the vmap seems to be necessary
    #rhs = lambda y, t: jnp.array([single_eval(0, t), single_eval(1, t)]) * y
    # using this version also seems to work,
    # so the issue may be with vmap + switch in this context?
    #rhs = lambda y, t: vmap(lambda x: a * x)(y)

    # returning this instead also works, so odeint is necessary?
    #return multiple_eval(idx_list, 1.) * jnp.ones((2, 2), dtype=complex)

    out = odeint(
        rhs,
        y0=y0,
        t=jnp.array([0, T], dtype=float),
        atol=1e-13,
        rtol=1e-13
    )
    return out

jit(grad(lambda a: test_func(a)[-1][1].real))(1.)
# evaluation without jitting works
#grad(lambda a: test_func(a)[-1][1].real)(1.)
```
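The comments in the reproducer suggest the failure needs `vmap` over `switch`. One relevant detail, sketched below with made-up branch functions (`branch0`, `branch1` are hypothetical stand-ins for `A0`/`A1`): when the branch index itself is batched, `vmap` does not keep a `switch`; it evaluates every branch and combines the results with a `select_n`. This only illustrates what `vmap(switch)` turns into and is not a claim about the root cause of the `odeint` error.

```python
import jax
import jax.numpy as jnp
from jax.lax import switch

def branch0(t):
    # constant branch, analogous to A0
    return jnp.full_like(t, 2.0)

def branch1(t):
    # t-dependent branch, analogous to A1 with a fixed coefficient
    return t ** 2

def single_eval(idx, t):
    return switch(idx, [branch0, branch1], t)

# Batch over the branch index only; t is shared across the batch (in_axes=None).
multiple_eval = jax.vmap(single_eval, in_axes=(0, None))

idx_list = jnp.array([0, 1])
t = jnp.float32(3.0)

out = multiple_eval(idx_list, t)
# Inspect the traced program: with a batched index, both branches are
# computed and the per-element result is chosen with select_n.
jaxpr = jax.make_jaxpr(multiple_eval)(idx_list, t)
print(out)
print(jaxpr)
```

This is why the `vmap` variant produces a different traced program from the explicit `jnp.array([single_eval(0, t), single_eval(1, t)])` variant, where each `switch` keeps a concrete, unbatched index.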
This results in the error:
```
TypeError: No constant handler for type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>
```
What jax/jaxlib version are you using?
jax v0.4.8, jaxlib v0.4.7
Which accelerator(s) are you using?
CPU
Additional system info
Mac, python 3.10
NVIDIA GPU info
No response
Thanks for the ping. This is a problem with odeint. @froystig and I spent some time working on a fix, and landed some PRs in that direction, but we never finished the fight.
@froystig do you remember the remaining steps we need to accomplish? IIRC we had to land symbolic zeros in custom_vjp, then maybe adjust closure_convert, then change odeint's use of static_argnums? Did we also have to document some changes needed to calling code?
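Since the discussion points at how `odeint` closure-converts `func`, one thing that may be worth trying until the fix lands is to pass `a` through `odeint`'s `*args` instead of capturing it in a closure. This is an untested assumption, not a workaround confirmed in this thread; the helper name `final_second_component` and the looser tolerances are illustrative choices. Here `rhs(y, t, a)` rebuilds the branch list from its own argument, so it closes over no traced values from the outer function.

```python
import jax
import jax.numpy as jnp
from jax.lax import switch
from jax.experimental.ode import odeint

T = jnp.pi * 1.232
y0 = jnp.array([0.5, 0.3])
idx_list = jnp.array([0, 1])

def rhs(y, t, a):
    # Branches are built from the explicit argument `a`, not a captured tracer.
    branches = [
        lambda s: jnp.full_like(s, 2.0),        # like A0
        lambda s: a**2 * jnp.ones_like(s),      # like A1(a, t)
    ]
    coeffs = jax.vmap(lambda i: switch(i, branches, t))(idx_list)
    return coeffs * y

def final_second_component(a):
    ts = jnp.array([0.0, T])
    # `a` is forwarded via odeint's *args, so odeint differentiates w.r.t. it
    # through its own adjoint rule rather than through a closed-over value.
    out = odeint(rhs, y0, ts, a, rtol=1e-6, atol=1e-6)
    return out[-1, 1]

g = jax.jit(jax.grad(final_second_component))(1.0)
print(g)
```

Because the second component solves y' = a² y, its final value is y0[1]·exp(a²T), so the gradient at a = 1 can be checked against the closed form y0[1]·exp(T)·2T.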