add axis env state to cache keys, fixes #9187 #9188

Closed

Collaborator

@mattjj mattjj commented Jan 13, 2022

fixes #9187

This change could cause some new and even unnecessary cache misses. That is, while the cache hits in #9187 were buggy and should be cache misses, to correct them we may end up being overly defensive.

The reason for defensiveness is that in general, even for a fixed sequence of input abstract values to a jitted function, whenever the named axis environment (which represents the mapped-over names bound by vmap, pmap, and/or xmap) differs from any we've seen before with those input abstract values, we need to re-trace the Python callable. That's because the Python callable could bind a collective like psum or axis_index, and the result of such a collective depends on the axis environment. (When a mapped axis has no name, i.e. when its name is core.no_axis_name, no collective can refer to it, so this problem can't arise. For caching purposes, unnamed axes can therefore be ignored.)
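
To make the caching argument concrete, here is a toy model in plain Python. It is only a sketch: `axis_env`, `axis_env_key`, `axis_size`, and `cached_trace` are hypothetical names invented for illustration, not JAX APIs, and `axis_size` merely stands in for a collective whose result depends on the bound axis.

```python
# Toy model only: none of these names are JAX APIs.
NO_AXIS_NAME = object()   # stand-in for core.no_axis_name

axis_env = []             # stack of (name, size) frames bound by enclosing maps
trace_cache = {}
trace_count = 0


def axis_env_key():
    # Unnamed axes cannot be referenced by a collective, so leave them out.
    return tuple((n, s) for n, s in axis_env if n is not NO_AXIS_NAME)


def axis_size(name):
    # Stand-in for a collective: its result depends on the axis environment.
    for n, s in reversed(axis_env):
        if n == name:
            return s
    raise NameError(f"unbound axis name: {name}")


def cached_trace(fun, abstract_args):
    # The cache key includes the named axis environment, not just the inputs.
    global trace_count
    key = (fun, abstract_args, axis_env_key())
    if key not in trace_cache:
        trace_count += 1
        trace_cache[key] = fun(*abstract_args)
    return trace_cache[key]


def f(aval):
    # Same abstract input every time, but the "trace" reads the axis size.
    return (aval, axis_size("i"))


axis_env.append(("i", 3))
r3 = cached_trace(f, ("float32[]",))      # first trace

axis_env[-1] = ("i", 5)
r5 = cached_trace(f, ("float32[]",))      # must re-trace: axis "i" changed

axis_env.append((NO_AXIS_NAME, 7))
r5_again = cached_trace(f, ("float32[]",))  # unnamed axis: cache hit
```

If the key omitted `axis_env_key()`, the second call would reuse the result traced with size 3, which is the kind of stale hit described in #9187.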

But some Python callables may not involve collectives at all! It'd be a waste to re-trace those. One such instance arose in this PR: the _multi_slice function, which is used in pmap dispatch, was getting re-traced (and re-compiled) even though it binds no collective primitives. So I introduced a low-level way to, in effect, promise that no collectives are bound: use core.stash_axis_env to temporarily empty out the axis environment and thus effectively ignore its value.
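
The "promise" can be sketched with a small context manager. This is a toy model under stated assumptions: `state`, `cache_key`, and this `stash_axis_env` are illustrative stand-ins written for this example, and the actual `core.stash_axis_env` in the PR operates on JAX's own trace state rather than on the namespace used here.

```python
import contextlib
import types

# Toy model only: a stand-in for the trace state, not JAX's real one.
state = types.SimpleNamespace(axis_env=[])


@contextlib.contextmanager
def stash_axis_env():
    """Temporarily empty the axis environment and restore it on exit."""
    prev, state.axis_env = state.axis_env, []
    try:
        yield
    finally:
        state.axis_env = prev


def cache_key(fun_name, abstract_args):
    # Hypothetical cache key: function, input avals, current axis env.
    return (fun_name, abstract_args, tuple(state.axis_env))


state.axis_env = [("i", 3)]
with stash_axis_env():
    key_a = cache_key("multi_slice", ("float32[8]",))

state.axis_env = [("i", 5)]
with stash_axis_env():
    key_b = cache_key("multi_slice", ("float32[8]",))

restored = list(state.axis_env)
```

Both keys are computed with an empty axis environment, so they are equal even though the surrounding axis `i` changed size, and a collective-free helper would hit the cache instead of being re-traced.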

Still, this issue could arise in user code... is the only option to provide such a "promise" API to users?

TODO:

  • wrap (almost) all jax.numpy functions in core.stash_axis_env (tried this and ran into a weird __doc__ issue...)

@mattjj mattjj force-pushed the cache-sensitivity-to-axis-env-state branch from a34acf6 to a63993f Compare January 13, 2022 05:14
@mattjj mattjj force-pushed the cache-sensitivity-to-axis-env-state branch 2 times, most recently from dfa06d2 to 5400420 Compare January 13, 2022 06:04
@mattjj mattjj requested a review from hawkinsp January 13, 2022 06:04
tests/pmap_test.py (outdated, resolved)
@mattjj mattjj force-pushed the cache-sensitivity-to-axis-env-state branch 2 times, most recently from f0f597f to 240079f Compare January 13, 2022 17:20
@mattjj mattjj self-assigned this Jan 14, 2022
jax/_src/util.py (outdated, resolved)
@mattjj mattjj force-pushed the cache-sensitivity-to-axis-env-state branch from 240079f to a7bb498 Compare January 15, 2022 02:00
@mattjj mattjj force-pushed the cache-sensitivity-to-axis-env-state branch from a7bb498 to 25bfd6f Compare January 15, 2022 02:04
jax/_src/config.py (outdated, resolved)
"Promise that a function or with-suite does not depend implicitly on axis env"
s = thread_local_state.trace_state
prev_axis_env, s.axis_env = s.axis_env, []
jax_config.update_thread_local_jit_state(axis_env_state=())
Collaborator

I wonder if it would be possible to put a dummy value that caused an error if you were to actually try to use it here?

Collaborator Author

Such things are possible...

The current approach would also lead to an error, just with a potentially misleading error message (unbound axis name).

Collaborator Author

I added a comment to this effect but didn't add any fancy error mechanisms. Gotta leave some work for the future, right?
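
For reference, the dummy-value idea from this thread might look roughly like the sketch below. It is purely hypothetical: `AxisEnv`, `PoisonedAxisEnv`, and `axis_size` are names made up for illustration, and nothing like this was added in the PR.

```python
class AxisEnv:
    """Toy axis environment mapping names to sizes (hypothetical)."""

    def __init__(self, frames):
        self.frames = dict(frames)

    def lookup(self, name):
        try:
            return self.frames[name]
        except KeyError:
            # The generic failure mode: reads as if the axis was never bound.
            raise NameError(f"unbound axis name: {name}") from None


class PoisonedAxisEnv:
    """Dummy environment that fails with a specific message on any use."""

    def lookup(self, name):
        raise RuntimeError(
            f"axis name {name!r} was used inside a region that promised "
            "not to depend on the axis environment"
        )


def axis_size(env, name):
    # Stand-in for a collective that consults the axis environment.
    return env.lookup(name)


size = axis_size(AxisEnv([("i", 3)]), "i")

try:
    axis_size(PoisonedAxisEnv(), "i")
    message = None
except RuntimeError as err:
    message = str(err)
```

Compared with simply emptying the environment, the poisoned value points at the broken promise rather than reporting an unbound axis name.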

@google-ml-butler google-ml-butler bot added the kokoro:force-run and pull ready (Ready for copybara import and testing) labels Jan 18, 2022
@mattjj mattjj force-pushed the cache-sensitivity-to-axis-env-state branch from 15e89d5 to 6bac182 Compare February 4, 2022 01:31
Collaborator Author

mattjj commented Jul 20, 2022

Subsumed by #11298.

@mattjj mattjj closed this Jul 20, 2022
@mattjj mattjj deleted the cache-sensitivity-to-axis-env-state branch July 20, 2022 22:19
Labels
pull ready Ready for copybara import and testing
Development

Successfully merging this pull request may close these issues.

caching not sensitive to bound axis names not present on inputs (args or closed over)
3 participants