Skip to content

Commit

Permalink
add axis env state to cache keys, fixes jax-ml#9187
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Jan 13, 2022
1 parent f0e4f04 commit f0f597f
Show file tree
Hide file tree
Showing 10 changed files with 67 additions and 21 deletions.
27 changes: 14 additions & 13 deletions jax/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import os
import sys
import threading
from typing import Any, List, Callable, NamedTuple, Optional
from typing import Any, List, Callable, NamedTuple, Optional, Hashable
import warnings

from jax._src import lib
Expand Down Expand Up @@ -324,15 +324,8 @@ def validate(new_val):

return _StateContextManager(name, help, update_thread_local_hook, validate)

def _trace_context(self):
  """Collect the configuration values that influence tracing.

  The resulting tuple is included in the cache key for linear_util.cache.
  Anything added here should most likely also be included in the C++ JIT
  state, which is handled separately.
  """
  x64 = self.x64_enabled
  rank_promotion = self.jax_numpy_rank_promotion
  matmul_precision = self.jax_default_matmul_precision
  return (x64, rank_promotion, matmul_precision)
def get_thread_local_trace_state(self):
  """Return the thread-local trace state.

  Method-level convenience that forwards to the module-level function of
  the same name; `self` is not consulted.
  """
  trace_state = get_thread_local_trace_state()
  return trace_state

class _StateContextManager:
def __init__(self, name, help, update_thread_local_hook,
Expand Down Expand Up @@ -405,7 +398,7 @@ def __setattr__(self, name, val):

class GlobalJitState(NamedTuple):
  """Immutable record of global JIT settings; every field defaults to None."""
  numpy_rank_promotion: Optional[str] = None
  # Declared exactly once. A second declaration of the same field would
  # silently shadow the first, so only the Hashable-typed one is kept.
  # NOTE(review): Hashable presumably because the value ends up in cache
  # keys -- confirm against callers.
  default_matmul_precision: Optional[Hashable] = None


def update_global_jit_state(**kw):
Expand All @@ -415,9 +408,10 @@ def update_global_jit_state(**kw):


class ThreadLocalJitState(NamedTuple):
  """Immutable record of per-thread JIT settings; every field defaults to None.

  Each field is declared exactly once. A repeated declaration of the same
  name would silently shadow the earlier one, so only the effective
  (Hashable-typed) declarations are kept, in their original field order.
  """
  dynamic_trace_state: Optional[Hashable] = None
  # Read back by get_thread_local_trace_state() as the first element of the
  # hashable trace-state tuple.
  axis_env_state: Optional[Hashable] = None
  numpy_rank_promotion: Optional[str] = None
  default_matmul_precision: Optional[Hashable] = None


def update_thread_local_jit_state(**kw):
Expand All @@ -426,6 +420,13 @@ def update_thread_local_jit_state(**kw):
tls.extra_jit_context = context._replace(**kw)


def get_thread_local_trace_state() -> Hashable:
  """Build the hashable tuple of thread-local state used as trace context.

  The tuple holds the axis-env state taken from the thread-local extra JIT
  context, followed by the current values of the x64, disable-jit, numpy
  rank-promotion and default-matmul-precision config options.
  """
  extra_context = jax_jit.thread_local_state().extra_jit_context
  if not extra_context:
    # Nothing usable stored on this thread: fall back to the defaults.
    extra_context = ThreadLocalJitState()
  return (
      extra_context.axis_env_state,
      config.jax_enable_x64,
      config.jax_disable_jit,
      config.jax_numpy_rank_promotion,
      config.jax_default_matmul_precision,
  )


# TODO(mattjj): remove all uses of this flag
flags.DEFINE_bool(
'jax_omnistaging',
Expand Down
3 changes: 0 additions & 3 deletions jax/_src/custom_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,6 @@ def _initial_style_jaxpr(fun, in_avals):
def _close_jaxpr(jaxpr):
  """Convert the constvars of `jaxpr` and wrap the result in a ClosedJaxpr.

  The ClosedJaxpr is built with an empty tuple as its second argument.
  """
  converted = pe.convert_constvars_jaxpr(jaxpr)
  return core.ClosedJaxpr(converted, ())

def _initial_style_staging() -> bool:
  """Report the `initial_style` flag of the thread-local trace state."""
  trace_state = core.thread_local_state.trace_state
  return trace_state.initial_style

def _sum_tangents(_, x, *xs):
  """Combine `x` with every element of `xs` using `ad.add_tangents`.

  The first positional argument is ignored. Accumulation runs left to
  right, starting from `x`; with no extra tangents, `x` is returned as is.
  """
  total = x
  for tangent in xs:
    total = ad.add_tangents(total, tangent)
  return total

Expand Down
3 changes: 3 additions & 0 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6769,11 +6769,14 @@ def _compress_method(a, condition, axis=None, out=None):
return compress(condition, a, axis, out)


@core.stash_axis_env()
@partial(jit, static_argnums=(1,2,3))
def _multi_slice(arr,
start_indices: Tuple[Tuple[int, ...]],
limit_indices: Tuple[Tuple[int, ...]],
removed_dims: Tuple[Tuple[int, ...]]):
print(core.thread_local_state.trace_state.axis_env)
breakpoint()
"""Extracts multiple slices from `arr`.
This is used to shard DeviceArray arguments to pmap. It's implemented as a
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def wrapper(*args, **kwargs):
if config.jax_check_tracer_leaks:
return f(*args, **kwargs)
else:
return cached(config._trace_context(), *args, **kwargs)
return cached(config.get_thread_local_trace_state(), *args, **kwargs)

wrapper.cache_clear = cached.cache_clear
wrapper.cache_info = cached.cache_info
Expand Down
20 changes: 20 additions & 0 deletions jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1854,20 +1854,40 @@ def _unmap_shaped_array(size: int, axis_name, axis: int, aval: ShapedArray) -> S
def extend_axis_env(axis_name: AxisName, size: int, tag: Any):
  """Push one axis frame onto the thread-local axis env around a body.

  A generator with a single `yield`: the frame is appended before the yield
  and popped afterwards, even if the body raises. After each change the
  current axis env is published, as a tuple, to the thread-local JIT state.
  NOTE(review): presumably wrapped by @contextmanager just above this
  definition (not visible here) -- confirm.
  """
  def _publish_axis_env():
    # Snapshot the live axis env as a tuple and mirror it into JIT state.
    snapshot = tuple(thread_local_state.trace_state.axis_env)
    jax_config.update_thread_local_jit_state(axis_env_state=snapshot)

  new_frame = AxisEnvFrame(axis_name, size, tag)
  thread_local_state.trace_state.axis_env.append(new_frame)
  _publish_axis_env()
  try:
    yield
  finally:
    thread_local_state.trace_state.axis_env.pop()
    _publish_axis_env()

@contextmanager
def extend_axis_env_nd(axes: Iterable[Tuple[AxisName, int]]):
  """Push one untagged axis frame per `(axis_name, size)` pair around a body.

  All frames are appended on entry and the same number of frames is popped
  on exit, even if the body raises. After each change the current axis env
  is published, as a tuple, to the thread-local JIT state.
  """
  def _publish_axis_env():
    # Snapshot the live axis env as a tuple and mirror it into JIT state.
    snapshot = tuple(thread_local_state.trace_state.axis_env)
    jax_config.update_thread_local_jit_state(axis_env_state=snapshot)

  new_frames = []
  for name, extent in axes:
    new_frames.append(AxisEnvFrame(name, extent, None))
  thread_local_state.trace_state.axis_env.extend(new_frames)
  _publish_axis_env()
  try:
    yield
  finally:
    # Pop exactly as many frames as were pushed on entry.
    for _ in range(len(new_frames)):
      thread_local_state.trace_state.axis_env.pop()
    _publish_axis_env()

@contextmanager
def stash_axis_env():
  """Promise that a function or with-suite does not depend implicitly on axis env.

  On entry the thread-local axis env is set aside and replaced with an
  empty list, and an empty tuple is published to the thread-local JIT
  state. On exit (even if the body raises) the saved axis env is put back
  and republished as a tuple.
  """
  trace_state = thread_local_state.trace_state
  saved_axis_env = trace_state.axis_env
  trace_state.axis_env = []
  jax_config.update_thread_local_jit_state(axis_env_state=())
  try:
    yield
  finally:
    trace_state.axis_env = saved_axis_env
    restored = tuple(trace_state.axis_env)
    jax_config.update_thread_local_jit_state(axis_env_state=restored)


# When a mapped function is given no axis name, we generate a name object based
Expand Down
1 change: 1 addition & 0 deletions jax/interpreters/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,7 @@ def out_axes_thunk():

# The freevars are being fanned out (not mapped). During transpose the
# dual of fan-out is fan-in-sum. We apply it to the unmapped invars.
# TODO(mattjj,jekbradbury): should this look at global_axis_size?
assert len(in_axes) == len(arg_cts)
def unmap_zero(zero, in_axis):
return (zero if in_axis is None else
Expand Down
6 changes: 3 additions & 3 deletions jax/linear_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,10 +260,10 @@ def memoized_fun(fun: WrappedFun, *args):
cache = fun_caches.setdefault(fun.f, {})
if config.jax_check_tracer_leaks:
key = (_copy_main_traces(fun.transforms), fun.params, args,
config.x64_enabled, config._trace_context())
config.get_thread_local_trace_state())
else:
key = (fun.transforms, fun.params, args, config.x64_enabled,
config._trace_context())
key = (fun.transforms, fun.params, args,
config.get_thread_local_trace_state())
result = cache.get(key, None)
if result is not None:
ans, stores = result
Expand Down
13 changes: 13 additions & 0 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,6 +828,19 @@ def f(d) -> float:
with self.assertRaisesRegex(TypeError, "'<' not supported.*"):
f({E.A: 1.0, E.B: 2.0})

def test_caches_depend_on_axis_env(self):
  """A jitted psum must match the un-jitted one for each vmap axis size."""
  # https://github.com/google/jax/issues/9187
  f = lambda: lax.psum(1, 'i')
  g = jax.jit(f)
  # The second iteration's call to g could erroneously get a cache hit from
  # the first one, since only the axis size differs.
  for axis_size in (2, 3):
    expected = jax.vmap(f, axis_name='i', axis_size=axis_size,
                        out_axes=None)()
    ans = jax.vmap(g, axis_name='i', axis_size=axis_size, out_axes=None)()
    self.assertEqual(ans, expected)


class PythonJitTest(CPPJitTest):

Expand Down
10 changes: 10 additions & 0 deletions tests/lax_control_flow_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2908,5 +2908,15 @@ def body(carry):
return lax.while_loop(cond, body, (i, jnp.ones(3)))[1]
jax.vmap(f, in_axes=(0, 1))(jnp.arange(4), jnp.ones((3, 4)))

def test_caches_depend_on_axis_env(self):
  """A psum inside a scan body must reflect each vmap axis size."""
  # https://github.com/google/jax/issues/9187
  def scanned_f(_, __):
    return lax.psum(1, 'i'), None

  def f():
    return lax.scan(scanned_f, 0, None, length=1)[0]

  # psum of 1 over axis 'i' should equal the axis size on every call.
  for axis_size in (2, 3):
    ans = jax.vmap(f, axis_name='i', axis_size=axis_size, out_axes=None)()
    self.assertEqual(ans, axis_size)


# Script entry point: run the tests via absltest using the JAX test loader.
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())
3 changes: 2 additions & 1 deletion tests/pmap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1880,8 +1880,9 @@ def f(x):
self.assertEqual(count[0], 2) # one for fwd, one for bwd

with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
_ = jax.vjp(f, x)
_, f_bwd2 = jax.vjp(f, x)
_ = f_bwd(x)
_ = f_bwd2(x)
self.assertEqual(count[0], 0) # cache hits on fwd and bwd

@unittest.skipIf(jax._src.lib._xla_extension_version < 44,
Expand Down

0 comments on commit f0f597f

Please sign in to comment.