How to deal with an impure reset function #312
Theo-Cheynel
started this conversation in
General
-
Hi @Theo-Cheynel, are you able to do something like this in the `env.reset`?

```python
import jax
import jax.numpy as jp

mocap = jp.ones((10, 3)) * jp.arange(0, 10)[:, None]  # load this once
rng = jax.random.PRNGKey(0)

def reset(rng):
    # Split the key so the caller gets a fresh one back; the sample depends only on rng.
    rng, key = jax.random.split(rng, 2)
    return rng, jax.random.choice(key, mocap, (1,))

rng, val = jax.jit(reset)(rng)
```
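If each environment should get its own clip, the same idea can probably be batched with `jax.vmap`. The following is only a minimal sketch under some assumptions: the clips are preloaded once and stacked into a single array (so they would need to share a length, or be padded to one), and the names `clips`, `NUM_CLIPS`, `NUM_FRAMES`, `DIM`, and `num_envs` are hypothetical placeholders rather than anything from the actual environment.

```python
import jax
import jax.numpy as jp

# Hypothetical dataset: 4 clips, each 6 frames of 3 values, stacked once up front.
NUM_CLIPS, NUM_FRAMES, DIM = 4, 6, 3
clips = jp.arange(NUM_CLIPS * NUM_FRAMES * DIM, dtype=jp.float32)
clips = clips.reshape(NUM_CLIPS, NUM_FRAMES, DIM)

def reset(rng):
    """Pick one clip from the key; the same key always yields the same clip."""
    rng, key = jax.random.split(rng)
    clip_id = jax.random.randint(key, (), 0, NUM_CLIPS)
    # Indexing with a traced integer is fine under jit; it becomes a gather.
    return rng, clip_id, clips[clip_id]

num_envs = 8  # hypothetical number of parallel environments
env_keys = jax.random.split(jax.random.PRNGKey(0), num_envs)

# One key per environment, so each environment samples independently.
batched_reset = jax.jit(jax.vmap(reset))
next_keys, clip_ids, ref_motion = batched_reset(env_keys)
```

Here `ref_motion` has shape `(num_envs, NUM_FRAMES, DIM)`. Because the randomness comes entirely from the key passed in, the function stays pure from JAX's point of view, and the returned `next_keys` can be stored in the environment state for the next reset.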
-
Hi,
I want to train an agent to imitate reference motions from a motion capture dataset. I wrote a custom environment, which works well when the reference motion capture clip is "hardcoded" (when it is the same throughout all environments).
However, I would like the reference clip to vary across envs. In other words, I want `env.reset` to pick a random motion clip (at the moment, motion clips are obtained from another class's `__getitem__` method).

The thing is, that would make the reset function impure, because its outputs would differ every time, and JAX jitting only supports pure functions. Is there a workaround you can think of?

At the very least, is it possible to tell `ppo.train` not to jit the reset function, while still jitting the step function?

Thanks for your help