Dear all, I have vmapped a jitted function and would like to call it 100 times. I was wondering if somebody could explain why this happens and whether there is any way to avoid this increase in computation time.

Warm-up time in seconds: 7.804678440093994

Best,
I suspect you're timing the iterations without calling `.block_until_ready()`, so the first few runs are just measuring the dispatch time rather than the actual computation time. See https://jax.readthedocs.io/en/latest/async_dispatch.html and https://jax.readthedocs.io/en/latest/faq.html#benchmarking-jax-code for more information.
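A minimal sketch of timing the calls with `.block_until_ready()`. The per-example function `f`, the input shape, and the choice to wrap the whole vmapped function in `jax.jit` are assumptions for illustration, not the original poster's code. The first call is timed separately because it includes tracing and compilation. Every later call blocks on its result, so each timing covers the computation and not just the asynchronous dispatch.

```python
import time

import jax
import jax.numpy as jnp

# Hypothetical per-example function standing in for the real workload.
def f(x):
    return jnp.sum(jnp.sin(x) ** 2)

# Vectorize over the leading axis, then compile the batched function once.
batched_f = jax.jit(jax.vmap(f))

x = jnp.ones((1000, 512))

# Warm-up call: includes tracing and XLA compilation.
# block_until_ready() makes the timer wait until the result is actually computed.
start = time.perf_counter()
batched_f(x).block_until_ready()
warmup_s = time.perf_counter() - start

# Time 100 calls, blocking on each result so we measure computation,
# not just how long it takes to enqueue the work.
timings = []
for _ in range(100):
    start = time.perf_counter()
    out = batched_f(x).block_until_ready()
    timings.append(time.perf_counter() - start)

print(f"Warm-up time in seconds: {warmup_s}")
print(f"Mean per-call time in seconds: {sum(timings) / len(timings)}")
```

If you drop `.block_until_ready()` from the loop, the early iterations can look very fast because they only enqueue work. Later iterations then appear slower once the device queue fills up, which can look like a growing per-call cost.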