
How can I jit a vmap function that takes a Pytree containing string "TypeError: Argument 'example' of type <class 'str'> is not a valid JAX type" #24658

Answered by ASEM000
thomashirtz asked this question in General

Can you try the following:

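from dataclasses import dataclass
from functools import partial

import chex
from jax.tree_util import register_dataclass


# 'type' is listed in meta_fields, so the string is static metadata and is never traced as a JAX value.
@partial(register_dataclass, data_fields=['raw_matrices', 'weights'], meta_fields=['type'])
@dataclass
class Kernel:
    type: str
    raw_matrices: chex.Array  # Shape (K, H_k, W_k)
    weights: chex.Array  # Shape (K,)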

The docs for `jax.tree_util.register_dataclass` have more details.
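For context, here is a minimal sketch of how a dataclass registered this way could be passed through `jax.jit(jax.vmap(...))`. The `combine` function, the `"normalized"` value, and the batch shapes are hypothetical and only illustrate the idea. `jax.Array` is used for the annotations instead of `chex.Array` so the sketch depends only on JAX. Because `type` is a meta field, it is part of the pytree structure rather than a leaf, so `vmap` maps only over the two array fields.

```python
from dataclasses import dataclass
from functools import partial

import jax
import jax.numpy as jnp
from jax.tree_util import register_dataclass


@partial(register_dataclass, data_fields=["raw_matrices", "weights"], meta_fields=["type"])
@dataclass
class Kernel:
    type: str                # static metadata, stored in the pytree structure
    raw_matrices: jax.Array  # shape (K, H_k, W_k)
    weights: jax.Array       # shape (K,)


def combine(kernel: Kernel) -> jax.Array:
    """Hypothetical per-example function: weighted sum of the K matrices."""
    mixed = jnp.tensordot(kernel.weights, kernel.raw_matrices, axes=1)  # (H_k, W_k)
    # 'type' is a plain Python string here, so ordinary control flow works.
    if kernel.type == "normalized":
        mixed = mixed / jnp.sum(kernel.weights)
    return mixed


# vmap maps over the leading axis of the array fields; the string is left alone.
batched_combine = jax.jit(jax.vmap(combine))

# A batch of 4 kernels, each with K=3 matrices of size 5x5 (illustrative shapes).
batch = Kernel(
    type="normalized",
    raw_matrices=jnp.ones((4, 3, 5, 5)),
    weights=jnp.ones((4, 3)),
)
out = batched_combine(batch)  # shape (4, 5, 5)
```

Since the meta field is part of the pytree structure, it must be hashable, and calling the jitted function with a different `type` value triggers a recompilation.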

Answer selected by thomashirtz