caching not sensitive to bound axis names not present on inputs (args or closed over) #9187
Labels
bug
Something isn't working
Comments
mattjj changed the title from "caching not sensitive to bound axis names" to "caching not sensitive to bound axis names not present on inputs (args or closed over)" on Jan 13, 2022
mattjj added nine commits to mattjj/jax that referenced this issue (seven on Jan 13, 2022 and two on Jan 15, 2022)
Hi @mattjj, I hope this issue can be closed in favor of #11298. I have tested the provided repro with JAX 0.4.38; it produces the same result with and without JIT.

Without JIT:

```python
>>> import jax
>>> from jax import lax
>>> def f():
...     return lax.psum(1, 'i')
>>> jax.vmap(f, axis_name='i', axis_size=2, out_axes=None)()
2
>>> jax.vmap(f, axis_name='i', axis_size=3, out_axes=None)()
3
```

With JIT:

```python
>>> f = jax.jit(f)
>>> jax.vmap(f, axis_name='i', axis_size=2, out_axes=None)()
Array(2, dtype=int32, weak_type=True)
>>> jax.vmap(f, axis_name='i', axis_size=3, out_axes=None)()
Array(3, dtype=int32, weak_type=True)
```

Attaching the colab gist for reference. Thank you.
Thank you!
Without `jit`, the behavior is correct. With `jit`, however, we get incorrect cache hits for this kind of computation.

The issue for C++ jit caching is that the `ThreadLocalJitState` on which caching depends isn't updated for changes to the trace-time named axis environment. For Python jit caching the same issue happens, since the trace context used as part of the cache key is also not sensitive to the axis env.
The same problem exists for the control flow primitive caches. The solution there is similarly to pass in a hashable version of the trace-time axis environment as an argument to the memoized jaxpr-forming functions.
How did we not notice until now? Well, any kind of data dependence (on a value mapped over the named axis in question) makes the caches work. So this seems to be primarily about things like jitting/scanning simple functions involving `axis_size('i')`, `axis_index('i')`, or `random.normal(key, shape=NamedShape(3, i=4))`. Tests involving the latter are what uncovered the issue!
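Why data dependence hides the problem can also be sketched with a toy cache, again in plain Python rather than JAX internals. The assumption here is that the cache key includes the shapes of the arguments, so a mapped input whose length equals the axis size changes the key as a side effect. All names (`_axis_env`, `trace`, `call_under_axis`) are hypothetical.

```python
import functools

# Hypothetical ambient axis environment: axis name -> axis size.
_axis_env = {}

@functools.lru_cache(maxsize=None)
def trace(fn_name, arg_shapes):
    # The cache key is (function name, argument shapes) only.
    # The traced result still depends on the ambient axis size.
    return _axis_env['i']

def call_under_axis(size, fn_name, mapped_arg=None):
    # Bind axis 'i' and call the cached tracer, passing the shape of the
    # mapped argument if there is one.
    _axis_env['i'] = size
    try:
        shapes = () if mapped_arg is None else (len(mapped_arg),)
        return trace(fn_name, shapes)
    finally:
        del _axis_env['i']

# A function of a value mapped over 'i': the input shape tracks the axis
# size, so the key changes and the cache happens to give the right answer.
with_data = [call_under_axis(2, 'g', [0, 0]), call_under_axis(3, 'g', [0, 0, 0])]

# A function with no inputs: nothing in the key changes, so the second
# call reuses the result traced for axis size 2.
no_data = [call_under_axis(2, 'f'), call_under_axis(3, 'f')]

print(with_data)  # [2, 3]
print(no_data)    # [2, 2]
```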