-
Notifications
You must be signed in to change notification settings - Fork 2.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add axis env state to cache keys, fixes #9187 #9188
Conversation
a34acf6
to
a63993f
Compare
dfa06d2
to
5400420
Compare
f0f597f
to
240079f
Compare
240079f
to
a7bb498
Compare
a7bb498
to
25bfd6f
Compare
"Promise that a function or with-suite does not depend implicitly on axis env" | ||
s = thread_local_state.trace_state | ||
prev_axis_env, s.axis_env = s.axis_env, [] | ||
jax_config.update_thread_local_jit_state(axis_env_state=()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder if it would be possible to put a dummy value that caused an error if you were to actually try to use it here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Such things are possible...
The current approach would also lead to an error, just with a potentially misleading error message (unbound axis name).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added a comment to this effect but didn't add any fancy error mechanisms. Gotta leave some work for the future, right?
15e89d5
to
6bac182
Compare
Subsumed by 11298. |
fixes #9187
This change could cause some new and even unnecessary cache misses. That is, while the cache hits in #9187 were buggy and should be cache misses, to correct them we may end up being overly defensive.
The reason for defensiveness is that in general, even for a fixed sequence of input abstract values to a jitted function, whenever the named axis environment (which represents the mapped-over names bound by
vmap
,pmap
, and/orxmap
) is different from one we've seen before with those input abstract values, we need to re-trace the Python callable. That's because the Python callable could bind a collective likepsum
oraxis_index
, and the result of such a collective depends on the axis environment. (When a mapped axis has no name, i.e. when its name iscore.no_axis_name
, this error can't happen. So for caching purposes unnamed axes can be ignored.)But some Python callables may not involve collectives at all! It'd be a waste to re-trace those. One such instance arose in this PR: the
_multi_slice
function, which is used inpmap
dispatch, was getting re-traced (and re-compiled) even though it binds no collective primitives. So I introduced a low-level way to, in effect, promise that no collectives are bound: usecore.stash_axis_env
to temporarily empty out the axis environment and thus effectively ignore its value.Still, this issue could arise in user code... is the only option to provide such a "promise" API to users?
TODO:
core.stash_axis_env
(tried this and ran into a weird__doc__
issue...)