Yes. Each new value of a static argument adds a cache entry, and repeated values are cache hits, as the jitted function's `_cache_size()` shows:

```python
import jax

def f(x, flag):
    return x if flag else x + 1

f_jit = jax.jit(f, static_argnums=1)
print(f_jit._cache_size())  # 0

f_jit(1.0, True)
print(f_jit._cache_size())  # 1

f_jit(1.0, False)
print(f_jit._cache_size())  # 2

# re-wrapped function hits the same cache
f_jit_2 = jax.jit(f, static_argnums=1)
print(f_jit_2._cache_size())  # 2

# cache hits don't increase the cache size
f_jit(100.0, True)
f_jit(100.0, False)
print(f_jit._cache_size())  # 2
```
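`_cache_size()` is a private method (note the leading underscore), so it may change or disappear between JAX versions. A minimal sketch of an alternative check that relies only on documented behavior: Python side effects inside a jitted function run only while JAX traces it, so appending to a list counts cache misses rather than calls. The `trace_count` list is a hypothetical name introduced for illustration.

```python
import jax

# Python-level record: the body of `f` only runs when JAX traces it,
# so this list grows on cache misses, not on every call.
trace_count = []

def f(x, flag):
    trace_count.append(flag)  # side effect, executed at trace time only
    return x if flag else x + 1

f_jit = jax.jit(f, static_argnums=1)

f_jit(1.0, True)     # new static value -> trace
f_jit(1.0, False)    # new static value -> trace
f_jit(100.0, True)   # same static value, same shape/dtype -> cache hit
f_jit(100.0, False)  # cache hit

print(len(trace_count))  # 2
```

The two later calls leave `trace_count` unchanged, which matches the `_cache_size()` result above without touching private API.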
I was searching the internet and the documentation to see whether `jax.jit` retains a cache for previously used static arguments to a function. I couldn't find any information, so I ran a quick test. It seems that `jax.jit` does retain the cache.
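The JAX documentation requires static arguments to be hashable (implementing `__hash__` and `__eq__`) and immutable, because their values take part in the cache lookup. A small sketch of what that implies, using a hypothetical `scale` function: two separate but equal tuples hit the same entry, a different tuple triggers a new trace, and an unhashable list is rejected.

```python
import jax
import jax.numpy as jnp

traces = []  # records each trace (cache miss) of `scale`

def scale(x, factors):
    traces.append(factors)  # Python side effect: runs only while tracing
    return x * jnp.asarray(factors).sum()

scale_jit = jax.jit(scale, static_argnums=1)

scale_jit(2.0, (1, 2))  # first trace
scale_jit(3.0, (1, 2))  # equal tuple -> cache hit, no new trace
scale_jit(2.0, (1, 3))  # different static value -> new trace
n_traces = len(traces)
print(n_traces)  # 2

# Lists are unhashable, so they cannot serve as static arguments.
error_name = None
try:
    scale_jit(2.0, [1, 2])
except (TypeError, ValueError) as e:
    error_name = type(e).__name__
print(error_name)
```

For unhashable configuration data, converting it to a tuple (or another hashable, immutable type) before passing it as a static argument keeps the cache usable.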