
"No constant handler" error appearing after jit/pjit merge #15759

Open
DanPuzzuoli opened this issue Apr 26, 2023 · 4 comments
Labels
bug Something isn't working


@DanPuzzuoli
Contributor

Description

This error was first reported in discussion #14561; I'm transferring it to an issue.

The self-contained example is below:

```python
import numpy as np
import jax.numpy as jnp
from jax import vmap, jit, grad
from jax.lax import switch
from jax.experimental.ode import odeint

def A0(t):
    return 2.

def A1(a, t):
    return a**2

y0 = np.random.rand(2)
T = np.pi * 1.232

def test_func(a):

    eval_list = [A0, lambda t: A1(a, t)]

    def single_eval(idx, t):
        return switch(idx, eval_list, t)
    multiple_eval = vmap(single_eval, in_axes=(0, None))
    idx_list = jnp.array([0, 1])
    rhs = lambda y, t: multiple_eval(idx_list, t) * y

    # Using this version of the RHS works, so the vmap seems to be necessary
    #rhs = lambda y, t: jnp.array([single_eval(0, t), single_eval(1, t)]) * y

    # Using this version also seems to work,
    # so the issue may be with vmap + switch in this context?
    #rhs = lambda y, t: vmap(lambda x: a * x)(y)

    # Returning this instead also works, so odeint seems to be necessary?
    #return multiple_eval(idx_list, 1.) * jnp.ones((2, 2), dtype=complex)

    out = odeint(
        rhs,
        y0=y0,
        t=jnp.array([0, T], dtype=float),
        atol=1e-13,
        rtol=1e-13
    )
    return out

jit(grad(lambda a: test_func(a)[-1][1].real))(1.)

# Evaluation without jit works:
#grad(lambda a: test_func(a)[-1][1].real)(1.)
```

This results in the error:

```
TypeError: No constant handler for type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>
```
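The commented-out variants in the reproducer point at the combination of `vmap` and `switch`. For context (not from the thread): when the index passed to `lax.switch` is batched by `vmap`, JAX cannot pick a single branch per call, so it evaluates every branch and selects the results elementwise; in the reproducer that means the branch closing over `a` is always traced. The sketch below uses hypothetical stand-ins for `A0`/`A1` with a fixed `a` and no `odeint`, so it does not reproduce the error; it only shows the batched evaluation.

```python
import jax
import jax.numpy as jnp
from jax.lax import switch

# Hypothetical stand-ins for A0 / A1 from the reproducer, with a fixed `a`.
a = jnp.float32(3.0)
branches = [lambda t: 2.0 * t, lambda t: a**2 * t]

def single_eval(idx, t):
    # With a scalar (unbatched) index, switch runs only the selected branch.
    return switch(idx, branches, t)

# Batching the index: switch can no longer choose one branch per call,
# so every branch is evaluated and the per-element results are selected.
multiple_eval = jax.vmap(single_eval, in_axes=(0, None))

idx_list = jnp.array([0, 1])
t = jnp.float32(1.0)
out = multiple_eval(idx_list, t)
print(out)  # [2. 9.]

# Print the staged-out program; with a batched index it typically shows a
# select over the results of both branches rather than a single cond.
print(jax.make_jaxpr(multiple_eval)(idx_list, t))
```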

What jax/jaxlib version are you using?

jax v0.4.8, jaxlib v0.4.7

Which accelerator(s) are you using?

CPU

Additional system info

Mac, Python 3.10


DanPuzzuoli added the bug label on Apr 26, 2023
@DanPuzzuoli
Contributor Author

Hey @mattjj, sorry to bug you, but any thoughts on this? I'm starting to run into issues with other packages that assume newer versions of JAX than 0.4.6.

@mattjj
Collaborator

mattjj commented Jun 1, 2023

Thanks for the ping. This is a problem with odeint. @froystig and I spent some time working on a fix, and landed some PRs in that direction, but we never finished the fight.

@froystig do you remember the remaining steps we need to accomplish? IIRC we had to land symbolic zeros in custom_vjp, then maybe adjust closure_convert, then change odeint's use of static_argnums? Did we also have to document some changes needed to calling code?
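Since the failure is attributed to `odeint`, it may help to recall its documented calling convention: `odeint(func, y0, t, *args, ...)`, where `func(y, t, *args)` receives the extra positional arguments, so a differentiable parameter can be passed explicitly instead of being captured in a closure. The sketch below applies that convention to a toy decay ODE with a known analytic gradient, under `jit(grad(...))` like the reproducer. The names `decay_rhs`/`final_value`, the values, and the tolerances are illustrative assumptions, and whether restructuring the reproducer this way avoids the error (its closed-over `idx_list` and `switch` branches would also need to change) has not been checked.

```python
import math
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

def decay_rhs(y, t, a):
    # dy/dt = -a * y; `a` arrives as an explicit argument forwarded by
    # odeint, not as a value captured from an enclosing scope.
    return -a * y

def final_value(a, T=1.5):
    ts = jnp.array([0.0, T])
    y0 = jnp.array([1.0])
    # Extra positional args after `t` are passed through to decay_rhs and
    # are differentiated through odeint's custom VJP.
    ys = odeint(decay_rhs, y0, ts, a, rtol=1e-6, atol=1e-6)
    return ys[-1, 0]

# Same jit(grad(...)) structure as the failing call in the reproducer.
dfinal_da = jax.jit(jax.grad(final_value))

a0, T = 0.5, 1.5
numeric = dfinal_da(a0)
analytic = -T * math.exp(-a0 * T)  # d/da exp(-a T) = -T exp(-a T)
print(numeric, analytic)
```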

@froystig
Member

froystig commented Jun 1, 2023

> then maybe adjust closure_convert, then change odeint's use of static_argnums

#14760 is the closure_convert PR. Relatedly, see the comment on #14760 regarding odeint.
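For readers unfamiliar with `closure_convert`: `jax.closure_convert(fun, *example_args)` traces `fun` at the given example arguments and returns a pair, a converted function plus a list of closed-over values that were hoisted out as explicit inputs; the converted function is called with the original arguments followed by those values. As I understand the implementation, `odeint` applies it to the user's RHS internally. The toy function `scaled_sin` below is hypothetical and unrelated to the bug; it only shows the mechanics under `jax.grad`.

```python
import math
import jax
import jax.numpy as jnp

def scaled_sin(a, x):
    # `f` closes over `a`; under jax.grad, `a` is a tracer.
    f = lambda y: a * jnp.sin(y)
    # Trace `f` at an example argument. The returned function no longer
    # closes over the hoisted values (here `a`, when it is traced).
    closed_f, hoisted = jax.closure_convert(f, x)
    # Call with the original arguments followed by the hoisted values.
    return closed_f(x, *hoisted)

x0 = jnp.float32(0.3)
value = scaled_sin(2.0, x0)             # 2 * sin(0.3)
grad_a = jax.grad(scaled_sin)(2.0, x0)  # d/da [a * sin(x)] = sin(x)
print(value, grad_a)
```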

@DanPuzzuoli
Contributor Author

Awesome, thanks. I mainly just wanted to see if it was something still on your radar.
