I'm trying to use a custom pytree class with `jax.vmap`, passing an instance of that class as `in_axes` and `out_axes`:

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class State:
    def __init__(self, t, x):
        self.t = t  # Scalar
        self.x = x  # Vector

    def tree_flatten(self):
        children = (self.t, self.x)
        aux_data = None
        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def dsdt(state):
    return State(t=1, x=state.t * state.x)


axes = State(t=None, x=0)
vec_dsdt = jax.vmap(dsdt, in_axes=axes, out_axes=axes)
vec_dsdt(State(t=1, x=jnp.array([1, 2.])))
```

The traceback includes
Answered by jakevdp on May 23, 2022
Answer selected by langmore

You should call it like this:

```python
vec_dsdt = jax.vmap(dsdt, in_axes=(axes,), out_axes=axes)
```
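`in_axes` needs to be a tuple, because in general the positional arguments (`*args`) of a function form a tuple, even when there is only a single argument. Wrapping `axes` as `(axes,)` makes it the spec for the first (and only) positional argument. This is perhaps slightly unclear because, for convenience in the common case, JAX allows passing a single integer in place of a length-1 tuple.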
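The same rule should apply to any pytree argument, not only a custom class. Here is a minimal sketch, assuming a plain dict stands in for `State` so the example stays short; the names `spec`, `square`, and `scaled` are illustrative only. It shows the 1-tuple form, that an unwrapped spec is rejected (the exception type has varied across JAX versions, so both are caught), the integer shorthand, and a two-argument case with one entry per argument.

```python
import jax
import jax.numpy as jnp


def dsdt(state):
    # Same computation as the question, with a dict standing in for State.
    return {"t": 1.0, "x": state["t"] * state["x"]}


spec = {"t": None, "x": 0}  # broadcast "t"; map "x" over axis 0
state = {"t": 2.0, "x": jnp.array([1.0, 2.0])}

# dsdt has one positional argument, so in_axes is a 1-tuple whose only
# entry is the (pytree-shaped) spec for that argument.
out = jax.vmap(dsdt, in_axes=(spec,), out_axes=spec)(state)
# out["x"] is [2.0, 4.0]; out["t"] is left unmapped.

# An unwrapped spec does not line up with the argument tuple (state,), so
# JAX rejects it (TypeError or ValueError, depending on the JAX version).
try:
    jax.vmap(dsdt, in_axes=spec, out_axes=spec)(state)
    rejected = False
except (TypeError, ValueError):
    rejected = True

# An int applies to every positional argument; with a single argument it
# behaves the same as a length-1 tuple.
xs = jnp.arange(3.0)
square = lambda v: v * v
a = jax.vmap(square, in_axes=0)(xs)
b = jax.vmap(square, in_axes=(0,))(xs)

# With two positional arguments, the tuple has one entry per argument:
# `scale` is broadcast (None) and `v` is mapped over axis 0.
scaled = jax.vmap(lambda scale, v: scale * v, in_axes=(None, 0))(10.0, xs)
```

`out_axes` needs no wrapping in this sketch because `dsdt` returns a single dict rather than a tuple of outputs, which is also why the corrected call keeps `out_axes=axes` unchanged.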