diff --git a/jax/_src/config.py b/jax/_src/config.py index c2d174283e47..b0a736394b33 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -21,7 +21,7 @@ import os import sys import threading -from typing import Any, List, Callable, NamedTuple, Optional +from typing import Any, List, Callable, NamedTuple, Optional, Hashable import warnings from jax._src import lib @@ -324,15 +324,8 @@ def validate(new_val): return _StateContextManager(name, help, update_thread_local_hook, validate) - def _trace_context(self): - """Returns a tuple of configuration values that affect tracing. - - These values are included in the cache key for linear_util.cache. - - Values included in this set should also most likely be included in - the C++ JIT state, which is handled separately.""" - return (self.x64_enabled, self.jax_numpy_rank_promotion, - self.jax_default_matmul_precision) + def get_thread_local_trace_state(self): + return get_thread_local_trace_state() class _StateContextManager: def __init__(self, name, help, update_thread_local_hook, @@ -405,7 +398,7 @@ def __setattr__(self, name, val): class GlobalJitState(NamedTuple): numpy_rank_promotion: Optional[str] = None - default_matmul_precision: Optional[Any] = None + default_matmul_precision: Optional[Hashable] = None def update_global_jit_state(**kw): @@ -415,9 +408,10 @@ def update_global_jit_state(**kw): class ThreadLocalJitState(NamedTuple): - dynamic_trace_state: Optional[Any] = None + dynamic_trace_state: Optional[Hashable] = None + axis_env_state: Optional[Hashable] = None numpy_rank_promotion: Optional[str] = None - default_matmul_precision: Optional[Any] = None + default_matmul_precision: Optional[Hashable] = None def update_thread_local_jit_state(**kw): @@ -426,6 +420,13 @@ def update_thread_local_jit_state(**kw): tls.extra_jit_context = context._replace(**kw) +def get_thread_local_trace_state() -> Hashable: + tls = jax_jit.thread_local_state() + ctx = tls.extra_jit_context or ThreadLocalJitState() + return (ctx.axis_env_state, config.jax_enable_x64, config.jax_disable_jit, + config.jax_numpy_rank_promotion, config.jax_default_matmul_precision) + + # TODO(mattjj): remove all uses of this flag flags.DEFINE_bool( 'jax_omnistaging', diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index e60110e66e8a..131859795e78 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -62,9 +62,6 @@ def _initial_style_jaxpr(fun, in_avals): def _close_jaxpr(jaxpr): return core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ()) -def _initial_style_staging() -> bool: - return core.thread_local_state.trace_state.initial_style - def _sum_tangents(_, x, *xs): return reduce(ad.add_tangents, xs, x) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 4b86cf72cb84..644f39c4443d 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -6769,11 +6769,14 @@ def _compress_method(a, condition, axis=None, out=None): return compress(condition, a, axis, out) +@core.stash_axis_env() @partial(jit, static_argnums=(1,2,3)) def _multi_slice(arr, start_indices: Tuple[Tuple[int, ...]], limit_indices: Tuple[Tuple[int, ...]], removed_dims: Tuple[Tuple[int, ...]]): + print(core.thread_local_state.trace_state.axis_env) + breakpoint() """Extracts multiple slices from `arr`. This is used to shard DeviceArray arguments to pmap. It's implemented as a diff --git a/jax/_src/util.py b/jax/_src/util.py index a762253c27b6..ad35c1467a3b 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -199,7 +199,7 @@ def wrapper(*args, **kwargs): if config.jax_check_tracer_leaks: return f(*args, **kwargs) else: - return cached(config._trace_context(), *args, **kwargs) + return cached(config.get_thread_local_trace_state(), *args, **kwargs) wrapper.cache_clear = cached.cache_clear wrapper.cache_info = cached.cache_info diff --git a/jax/core.py b/jax/core.py index db7fd96d7393..8d717930b3a4 100644 --- a/jax/core.py +++ b/jax/core.py @@ -1854,20 +1854,40 @@ def _unmap_shaped_array(size: int, axis_name, axis: int, aval: ShapedArray) -> S def extend_axis_env(axis_name: AxisName, size: int, tag: Any): frame = AxisEnvFrame(axis_name, size, tag) thread_local_state.trace_state.axis_env.append(frame) + jax_config.update_thread_local_jit_state( + axis_env_state=tuple(thread_local_state.trace_state.axis_env)) try: yield finally: thread_local_state.trace_state.axis_env.pop() + jax_config.update_thread_local_jit_state( + axis_env_state=tuple(thread_local_state.trace_state.axis_env)) @contextmanager def extend_axis_env_nd(axes: Iterable[Tuple[AxisName, int]]): frames = [AxisEnvFrame(axis_name, size, None) for axis_name, size in axes] thread_local_state.trace_state.axis_env.extend(frames) + jax_config.update_thread_local_jit_state( + axis_env_state=tuple(thread_local_state.trace_state.axis_env)) try: yield finally: for _ in frames: thread_local_state.trace_state.axis_env.pop() + jax_config.update_thread_local_jit_state( + axis_env_state=tuple(thread_local_state.trace_state.axis_env)) + +@contextmanager +def stash_axis_env(): + "Promise that a function or with-suite does not depend implicitly on axis env" + s = thread_local_state.trace_state + prev_axis_env, s.axis_env = s.axis_env, [] + jax_config.update_thread_local_jit_state(axis_env_state=()) + try: + yield + finally: + s.axis_env = prev_axis_env + jax_config.update_thread_local_jit_state(axis_env_state=tuple(s.axis_env)) # When a mapped function is given no axis name, we generate a name object based diff --git a/jax/interpreters/ad.py b/jax/interpreters/ad.py index 1a6398aa4443..fbfdf2189c1a 100644 --- a/jax/interpreters/ad.py +++ b/jax/interpreters/ad.py @@ -625,6 +625,7 @@ def out_axes_thunk(): # The freevars are being fanned out (not mapped). During transpose the # dual of fan-out is fan-in-sum. We apply it to the unmapped invars. + # TODO(mattjj,jekbradbury): should this look at global_axis_size? assert len(in_axes) == len(arg_cts) def unmap_zero(zero, in_axis): return (zero if in_axis is None else diff --git a/jax/linear_util.py b/jax/linear_util.py index 397790aec701..df34b675321b 100644 --- a/jax/linear_util.py +++ b/jax/linear_util.py @@ -260,10 +260,10 @@ def memoized_fun(fun: WrappedFun, *args): cache = fun_caches.setdefault(fun.f, {}) if config.jax_check_tracer_leaks: key = (_copy_main_traces(fun.transforms), fun.params, args, - config.x64_enabled, config._trace_context()) + config.get_thread_local_trace_state()) else: - key = (fun.transforms, fun.params, args, config.x64_enabled, - config._trace_context()) + key = (fun.transforms, fun.params, args, + config.get_thread_local_trace_state()) result = cache.get(key, None) if result is not None: ans, stores = result diff --git a/tests/api_test.py b/tests/api_test.py index 692931c25bfc..59fa0ceabd10 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -828,6 +828,19 @@ def f(d) -> float: with self.assertRaisesRegex(TypeError, "'<' not supported.*"): f({E.A: 1.0, E.B: 2.0}) + def test_caches_depend_on_axis_env(self): + # https://github.com/google/jax/issues/9187 + f = lambda: lax.psum(1, 'i') + g = jax.jit(f) + expected = jax.vmap(f, axis_name='i', axis_size=2, out_axes=None)() + ans = jax.vmap(g, axis_name='i', axis_size=2, out_axes=None)() + self.assertEqual(ans, expected) + + # This second call to g could erroneously get a cache hit. + expected = jax.vmap(f, axis_name='i', axis_size=3, out_axes=None)() + ans = jax.vmap(g, axis_name='i', axis_size=3, out_axes=None)() + self.assertEqual(ans, expected) + class PythonJitTest(CPPJitTest): diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index d32bce80f4b6..26561eb23ee2 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -2908,5 +2908,15 @@ def body(carry): return lax.while_loop(cond, body, (i, jnp.ones(3)))[1] jax.vmap(f, in_axes=(0, 1))(jnp.arange(4), jnp.ones((3, 4))) + def test_caches_depend_on_axis_env(self): + # https://github.com/google/jax/issues/9187 + scanned_f = lambda _, __: (lax.psum(1, 'i'), None) + f = lambda: lax.scan(scanned_f, 0, None, length=1)[0] + ans = jax.vmap(f, axis_name='i', axis_size=2, out_axes=None)() + self.assertEqual(ans, 2) + ans = jax.vmap(f, axis_name='i', axis_size=3, out_axes=None)() + self.assertEqual(ans, 3) + + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pmap_test.py b/tests/pmap_test.py index 5e8db1e5d304..37a7c8ee5b49 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -1880,8 +1880,9 @@ def f(x): self.assertEqual(count[0], 2) # one for fwd, one for bwd with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841 - _ = jax.vjp(f, x) + _, f_bwd2 = jax.vjp(f, x) _ = f_bwd(x) + _ = f_bwd2(x) self.assertEqual(count[0], 0) # cache hits on fwd and bwd @unittest.skipIf(jax._src.lib._xla_extension_version < 44,