How can I jit a vmapped function that takes a pytree containing a string? "TypeError: Argument 'example' of type <class 'str'> is not a valid JAX type" #24658
I am trying to pass a pytree to a function and vmap that function, but I can't find a "clean" way to do it. The only approach I have found so far is to deconstruct the tree and pass its parts separately, which doesn't feel clean. I get the error `TypeError: Argument 'example' of type <class 'str'> is not a valid JAX type`. Here is a minimal example of what I am trying to achieve:
Thanks!
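As a hypothetical stand-in for the minimal example (the dict layout and the `scale` function are illustrative assumptions, not the OP's code), this sketch shows why the error appears: every leaf of a pytree passed to `jax.jit` must be an array-like value, and a Python `str` leaf is rejected with a `TypeError`.

```python
import jax
import jax.numpy as jnp


def scale(params):
    # Only the array is used, but jit still has to abstract every leaf,
    # including the string stored under "type".
    return params["w"] * 2.0


# Hypothetical pytree: a plain dict, so "gaussian" becomes a leaf.
params = {"type": "gaussian", "w": jnp.ones(3)}


def jit_rejects_string_leaf() -> bool:
    try:
        jax.jit(scale)(params)
    except TypeError as err:
        print("TypeError:", err)
        return True
    return False


print(jit_rejects_string_leaf())
```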
Can you try the following:

```python
from dataclasses import dataclass
from functools import partial

import chex
from jax.tree_util import register_dataclass


@partial(register_dataclass, data_fields=['raw_matrices', 'weights'], meta_fields=['type'])
@dataclass
class Kernel:
    type: str
    raw_matrices: chex.Array  # Shape (K, H_k, W_k)
    weights: chex.Array  # Shape (K,)
```

The docs have more details here.
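A minimal usage sketch, assuming a hypothetical `combine` function and swapping `chex.Array` for `jax.Array` so it runs without chex: because `type` is registered as a meta field, it is stored in the pytree structure rather than as a leaf, so `jax.vmap` maps only over the two array fields and `jax.jit` sees a plain Python string it can branch on.

```python
from dataclasses import dataclass
from functools import partial

import jax
import jax.numpy as jnp
from jax.tree_util import register_dataclass


@partial(register_dataclass, data_fields=["raw_matrices", "weights"], meta_fields=["type"])
@dataclass
class Kernel:
    type: str                # meta field: part of the treedef, not a leaf
    raw_matrices: jax.Array  # per example: (K, H_k, W_k)
    weights: jax.Array       # per example: (K,)


def combine(kernel: Kernel) -> jax.Array:
    # kernel.type is an ordinary Python str here, so a regular `if` works.
    weighted = kernel.weights[:, None, None] * kernel.raw_matrices
    if kernel.type == "sum":
        return weighted.sum(axis=0)
    return weighted.mean(axis=0)


# vmap's default in_axes=0 applies to every array leaf of the Kernel.
batched = jax.jit(jax.vmap(combine))

# Batch of 4 kernels, each with K=2 matrices of shape 3x3.
kernels = Kernel(type="sum", raw_matrices=jnp.ones((4, 2, 3, 3)), weights=jnp.ones((4, 2)))
out = batched(kernels)
print(out.shape)  # (4, 3, 3)
```

All kernels in one batch share the same `type`, since it belongs to the structure rather than the data; a different string produces a separately compiled version.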
On a side note, regarding this line in the OP's code, `# @partial(jit, static_argnames=['kernel'])`, I think the …
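Related to the commented-out `static_argnames=['kernel']`: static arguments to `jax.jit` must be hashable, and a plain `@dataclass` (default `eq=True`, `frozen=False`) has its `__hash__` set to `None`, so it cannot be used as a static argument as-is. With `type` registered as a meta field, `static_argnames` should not be needed, because `jit` already keys its cache on the pytree structure, which includes the string. This sketch uses a reduced, hypothetical `Kernel` and `apply` and counts traces to show that behavior.

```python
from dataclasses import dataclass
from functools import partial

import jax
import jax.numpy as jnp
from jax.tree_util import register_dataclass

trace_count = 0  # incremented only while JAX traces `apply`


@partial(register_dataclass, data_fields=["weights"], meta_fields=["type"])
@dataclass
class Kernel:
    type: str
    weights: jax.Array


@jax.jit  # no static_argnames: `type` is already part of the treedef
def apply(kernel: Kernel, x: jax.Array) -> jax.Array:
    global trace_count
    trace_count += 1  # Python side effect: runs once per trace, not per call
    if kernel.type == "double":
        return 2.0 * kernel.weights * x
    return kernel.weights * x


w = jnp.ones(3)
x = jnp.arange(3.0)
apply(Kernel("double", w), x)
apply(Kernel("double", w), x)  # same string and shapes: cache hit, no retrace
apply(Kernel("plain", w), x)   # new string: new treedef, so a new trace
print(trace_count)  # 2
```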