

caching not sensitive to bound axis names not present on inputs (args or closed over) #9187

Closed
mattjj opened this issue Jan 13, 2022 · 2 comments
Labels
bug Something isn't working


@mattjj
Collaborator

mattjj commented Jan 13, 2022

Here's correct behavior, without jit:

In [3]: def f():
   ...:     return lax.psum(1, 'i')
   ...:

In [6]: jax.vmap(f, axis_name='i', axis_size=2, out_axes=None)()
Out[6]: 2

In [7]: jax.vmap(f, axis_name='i', axis_size=3, out_axes=None)()
Out[7]: 3

But we get incorrect jit cache hits for this kind of computation:

In [8]: f = jax.jit(f)

In [9]: jax.vmap(f, axis_name='i', axis_size=3, out_axes=None)()
Out[9]: DeviceArray(3, dtype=int32, weak_type=True)

In [10]: jax.vmap(f, axis_name='i', axis_size=2, out_axes=None)()
Out[10]: DeviceArray(3, dtype=int32, weak_type=True)

The issue for C++ jit caching is that the ThreadLocalJitState on which caching depends isn't updated for changes to the trace-time named axis environment.

The same issue affects Python jit caching, since the trace context used as part of the cache key is also insensitive to the axis env.
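To make the failure mode concrete, here is a toy model in plain Python. It is not JAX's actual implementation. It assumes a global dict `AXIS_ENV` standing in for the trace-time named axis environment and a memoized `trace_and_compile` whose cache key is only the function. All names here (`AXIS_ENV`, `psum_one`, `trace_and_compile`, `run_with_axis`) are hypothetical. The axis size read during tracing gets baked into the cached result, so the second call reuses it, mirroring Out[9] and Out[10] above.

```python
import functools

# Hypothetical stand-in for the trace-time named axis environment.
AXIS_ENV = {}

def psum_one(axis_name):
    # Stands in for lax.psum(1, axis_name): the size of the bound axis.
    return AXIS_ENV[axis_name]

@functools.lru_cache(maxsize=None)
def trace_and_compile(fun):
    # BUG: the cache key is just `fun`. The value read from AXIS_ENV while
    # "tracing" is baked into the result, but AXIS_ENV is not in the key.
    traced_value = fun()
    return lambda: traced_value

def run_with_axis(fun, axis_name, axis_size):
    # Rough analogue of vmap(fun, axis_name=..., axis_size=..., out_axes=None)().
    AXIS_ENV[axis_name] = axis_size
    try:
        return trace_and_compile(fun)()
    finally:
        del AXIS_ENV[axis_name]

def f():
    return psum_one("i")

first = run_with_axis(f, "i", 3)   # traces with i of size 3
second = run_with_axis(f, "i", 2)  # stale cache hit: reuses the size-3 trace
```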

The same problem exists for the control flow primitive caches:

In [21]: scanned_f = lambda _, __: (lax.psum(1, 'i'), None)

In [22]: f = lambda: lax.scan(scanned_f, 0, None, length=1)[0]

In [23]: jax.vmap(f, axis_name='i', axis_size=2, out_axes=None)()
Out[23]: DeviceArray(2, dtype=int32, weak_type=True)

In [24]: jax.vmap(f, axis_name='i', axis_size=3, out_axes=None)()
Out[24]: DeviceArray(2, dtype=int32, weak_type=True)

The fix there is similar: pass a hashable version of the trace-time axis environment as an argument to the memoized jaxpr-forming functions.

How did we not notice until now? Well, any kind of data dependence (on a value mapped over the named axis in question) makes the caches work. So this seems to be primarily about things like jitting/scanning simple functions involving axis_size('i'), axis_index('i'), or random.normal(key, shape=NamedShape(3, i=4)). Tests involving the latter are what uncovered the issue!
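The same kind of toy model can show why data dependence hides the bug. As a simplification, it assumes the cache key includes the shapes of the inputs and that a mapped argument carries the axis size as its leading dimension; `AXIS_ENV`, `trace_and_compile`, and `run_with_axis` are hypothetical names, not JAX internals. With no mapped inputs the keys collide and the second call returns a stale value; with a mapped input the keys differ, so each axis size gets its own trace.

```python
import functools

# Hypothetical stand-in for the trace-time named axis environment.
AXIS_ENV = {}

@functools.lru_cache(maxsize=None)
def trace_and_compile(fun, arg_shapes):
    # Cache key is (fun, arg_shapes); AXIS_ENV is still not part of it.
    traced_value = fun()  # reads AXIS_ENV while "tracing"
    return lambda: traced_value

def run_with_axis(fun, axis_size, *mapped_args):
    # Rough analogue of vmap(..., axis_name='i', axis_size=axis_size, out_axes=None).
    AXIS_ENV["i"] = axis_size
    try:
        # Each mapped argument has the axis size as its leading dimension.
        # In this toy model the arguments only affect the cache key.
        arg_shapes = tuple((len(a),) for a in mapped_args)
        return trace_and_compile(fun, arg_shapes)()
    finally:
        del AXIS_ENV["i"]

def f():
    # Stands in for lax.psum(1, 'i'): the size of the bound axis.
    return AXIS_ENV["i"]

# No mapped inputs: both calls share the key (f, ()), so the second is stale.
no_data = [run_with_axis(f, 3), run_with_axis(f, 2)]
# One mapped input whose length is the axis size: keys differ, so each call retraces.
with_data = [run_with_axis(f, 5, [0] * 5), run_with_axis(f, 4, [0] * 4)]
```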

@mattjj mattjj added the bug Something isn't working label Jan 13, 2022
@mattjj mattjj self-assigned this Jan 13, 2022
@mattjj mattjj changed the title caching not sensitive to bound axis names caching not sensitive to bound axis names not present on inputs (args or closed over) Jan 13, 2022
mattjj added seven commits to mattjj/jax that referenced this issue Jan 13, 2022
mattjj added two commits to mattjj/jax that referenced this issue Jan 15, 2022
@rajasekharporeddy
Contributor

Hi @mattjj

I hope this issue can be closed in favor of #11298. I tested the provided repro with JAX 0.4.38, and it produces the same result with and without JIT.

Without JIT:

>>> import jax
>>> from jax import lax
>>> def f():
...  return lax.psum(1, 'i')
>>> jax.vmap(f, axis_name='i', axis_size=2, out_axes=None)()
2
>>> jax.vmap(f, axis_name='i', axis_size=3, out_axes=None)()
3

With JIT:

>>> f = jax.jit(f)
>>> jax.vmap(f, axis_name='i', axis_size=2, out_axes=None)()
Array(2, dtype=int32, weak_type=True)
>>> jax.vmap(f, axis_name='i', axis_size=3, out_axes=None)()
Array(3, dtype=int32, weak_type=True)

Attaching the colab gist for reference.

Thank you.

@mattjj
Collaborator Author

mattjj commented Jan 7, 2025

Thank you!

@mattjj mattjj closed this as completed Jan 7, 2025