Is your feature request related to a problem? Please describe
`jax.random.choice` seems to be slow, especially when sampling without replacement. Sampling with replacement seems to be much faster, but even then it is probably slower than `jax.random.randint` (to be verified).
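A minimal timing sketch for verifying this, assuming JAX on a single device. The population size `N`, batch size `BATCH` and all function names (`sample_randint`, `time_sampler`, etc.) are hypothetical choices for illustration, not taken from the notebook. Each sampler draws 2 indices per key, is jitted and vmapped over a batch of keys, and a warm-up call keeps compilation out of the measurement.

```python
import time

import jax

N = 144       # hypothetical population size, e.g. the cells of a 12x12 grid
BATCH = 4096  # hypothetical number of independent draws per call


def sample_randint(key):
    # 2 indices in [0, N), with replacement.
    return jax.random.randint(key, (2,), 0, N)


def sample_choice_with(key):
    # 2 indices in [0, N), with replacement.
    return jax.random.choice(key, N, shape=(2,), replace=True)


def sample_choice_without(key):
    # 2 distinct indices in [0, N).
    return jax.random.choice(key, N, shape=(2,), replace=False)


def time_sampler(fn, n_iter=100, seed=0):
    """Mean wall-clock seconds per batched call, excluding compilation."""
    batched = jax.jit(jax.vmap(fn))
    keys = jax.random.split(jax.random.PRNGKey(seed), BATCH)
    batched(keys).block_until_ready()  # warm-up call triggers compilation
    start = time.perf_counter()
    for _ in range(n_iter):
        batched(keys).block_until_ready()  # wait for async dispatch to finish
    return (time.perf_counter() - start) / n_iter


samplers = (sample_randint, sample_choice_with, sample_choice_without)
timings = {fn.__name__: time_sampler(fn) for fn in samplers}
for name, seconds in timings.items():
    print(f"{name}: {seconds * 1e6:.1f} us per batch of {BATCH}")
```

Absolute numbers depend on hardware and JAX version, so only the relative ordering of the three samplers is meaningful.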
Describe the solution you'd like
Find a way to use `jax.random.choice(replace=False)` as little as possible to improve environment speed.
Alternatives considered
As a first study towards this, it turns out that `jax.random.choice(..., replace=True)` is faster than `jax.random.categorical` for sampling with replacement, while `jax.random.choice(..., replace=False)` appears much slower than both. When sampling without replacement is needed, the best approach still has to be studied.
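To make the with-replacement comparison concrete, here is a sketch of the two calls drawing from the same weighted distribution. The weights over four actions are hypothetical. Note that `jax.random.choice` takes probabilities `p`, whereas `jax.random.categorical` takes logits (unnormalised log-probabilities), so the same distribution is expressed as `log(weights)`.

```python
import jax
import jax.numpy as jnp

# Hypothetical unnormalised weights over 4 actions.
weights = jnp.array([1.0, 2.0, 3.0, 4.0])
probs = weights / weights.sum()  # approximately [0.1, 0.2, 0.3, 0.4]


def sample_with_choice(key, num_samples):
    # Takes probabilities `p`; draws with replacement.
    return jax.random.choice(
        key, weights.shape[0], shape=(num_samples,), replace=True, p=probs
    )


def sample_with_categorical(key, num_samples):
    # Takes logits (unnormalised log-probabilities); draws with replacement.
    return jax.random.categorical(key, jnp.log(weights), shape=(num_samples,))


key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
num = 100_000
freq_choice = jnp.bincount(sample_with_choice(key_a, num), length=4) / num
freq_categorical = jnp.bincount(sample_with_categorical(key_b, num), length=4) / num
print(freq_choice, freq_categorical)  # both should be close to probs
```

Both empirical frequencies should approach `probs`; any speed difference between them comes from the sampling algorithm, not from the distribution.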
Source: notebook
Alternatives to `jax.random.choice(..., replace=False)` that could be considered and assessed include:
- Sampling once from the joint distribution that gathers all the valid pairs.
- Creating two random partitions `p1` and `p2`, sampling a single index `i`, and getting the two samples as `p1[i]` and `p2[i]`.
- Sequentially sampling one index, then sampling a second one from the conditional distribution given that the first one is no longer available.
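The three alternatives can be sketched for the Snake-like case of drawing 2 distinct indices uniformly from `N` cells. `N` and all function and table names are hypothetical. For the second idea, `P1` and `P2` are assumed to be index tables enumerating every valid ordered pair; this is only one possible reading of the description. Each function returns `[i, j]` with `i != j`, and every ordered pair is equally likely.

```python
import jax
import jax.numpy as jnp

N = 144  # hypothetical number of cells, e.g. a 12x12 Snake grid


def pair_baseline(key):
    # Reference: 2 distinct indices via choice without replacement.
    return jax.random.choice(key, N, shape=(2,), replace=False)


def pair_joint(key):
    # Alternative 1: one draw r in [0, N*(N-1)) indexes every ordered pair
    # (i, j) with i != j; decode it with integer arithmetic.
    r = jax.random.randint(key, (), 0, N * (N - 1))
    i = r // (N - 1)
    j = r % (N - 1)
    j = j + (j >= i)  # skip over i, so j ranges over [0, N) minus {i}
    return jnp.stack([i, j])


def make_pair_tables(n):
    # Alternative 2 (one reading): precompute index tables that list every
    # valid ordered pair once. Runs eagerly, outside jit.
    i, j = jnp.meshgrid(jnp.arange(n), jnp.arange(n), indexing="ij")
    valid = i != j
    return i[valid], j[valid]


P1, P2 = make_pair_tables(N)


def pair_table(key):
    # One randint, then two gathers: P1[k] != P2[k] by construction.
    k = jax.random.randint(key, (), 0, P1.shape[0])
    return jnp.stack([P1[k], P2[k]])


def pair_sequential(key):
    # Alternative 3: draw i, then j from the conditional distribution
    # (uniform over the N - 1 indices other than i).
    key_i, key_j = jax.random.split(key)
    i = jax.random.randint(key_i, (), 0, N)
    j = jax.random.randint(key_j, (), 0, N - 1)
    j = j + (j >= i)  # same skip-over-i step as in pair_joint
    return jnp.stack([i, j])


keys = jax.random.split(jax.random.PRNGKey(0), 1000)
results = {
    fn.__name__: jax.jit(jax.vmap(fn))(keys)
    for fn in (pair_baseline, pair_joint, pair_table, pair_sequential)
}
```

Alternatives 1 and 3 use the same "skip over `i`" step instead of rejection sampling, so neither needs a loop. Alternative 2 trades memory (two tables of `N * (N - 1)` entries) for a single `randint` plus two gathers. Which one is fastest still has to be measured.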
It may turn out that `jax.random.choice(..., replace=False)` is already the most optimised option. In any case, the best solution may depend on how many items need to be drawn without replacement (e.g. 2 in the case of Snake).
Remarks
We need to take this into account for random policies: random action selection likely has a large effect on the measured environment speed, which would bias speed benchmarks.
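For random policies, a sketch of two ways to sample a uniformly random valid action, assuming the environment exposes a boolean `action_mask` over 4 actions. The mask layout and function names are hypothetical, not the environment's actual API. Either function can be timed on its own to estimate how much action selection contributes to a speed benchmark.

```python
import jax
import jax.numpy as jnp


def random_action_categorical(key, action_mask):
    # Invalid actions get -inf logits, so they are never sampled; valid
    # actions share equal logits, hence equal probability.
    logits = jnp.where(action_mask, 0.0, -jnp.inf)
    return jax.random.categorical(key, logits)


def random_action_choice(key, action_mask):
    # Uniform probabilities over valid actions, zero elsewhere.
    p = action_mask / action_mask.sum()
    return jax.random.choice(key, action_mask.shape[-1], p=p)


# Hypothetical batch of masks over 4 actions (True = action allowed).
masks = jnp.array([[True, True, False, True], [False, True, False, False]])
keys = jax.random.split(jax.random.PRNGKey(0), masks.shape[0])
a_categorical = jax.jit(jax.vmap(random_action_categorical))(keys, masks)
a_choice = jax.jit(jax.vmap(random_action_choice))(keys, masks)
```

The second mask allows a single action, so both functions must return action 1 for that row.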