From 8b5a9f5916121db6cb40ce021f63012b818be07c Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Wed, 12 Jan 2022 13:14:02 -0800 Subject: [PATCH] cache tracing of (sub)calls when forming a jaxpr --- jax/_src/numpy/lax_numpy.py | 2 -- jax/interpreters/partial_eval.py | 20 ++++++++++++-- tests/api_test.py | 46 ++++++++++++++++++++++++++++++-- tests/pmap_test.py | 2 +- 4 files changed, 63 insertions(+), 7 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 644f39c4443d..843a1c7f2cd1 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -6775,8 +6775,6 @@ def _multi_slice(arr, start_indices: Tuple[Tuple[int, ...]], limit_indices: Tuple[Tuple[int, ...]], removed_dims: Tuple[Tuple[int, ...]]): - print(core.thread_local_state.trace_state.axis_env) - breakpoint() """Extracts multiple slices from `arr`. This is used to shard DeviceArray arguments to pmap. It's implemented as a diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index 3e35b568852a..18a52cf32e7e 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -1361,8 +1361,13 @@ def process_call(self, call_primitive, f, tracers, params): in_avals = _tracers_to_avals({}, dim_tracers + tracers) keep_inputs = [False] * len(dim_tracers) + [True] * len(tracers) with core.new_sublevel(): - jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic( - f, self.main, in_avals, keep_inputs=keep_inputs) + if config.jax_check_tracer_leaks: + # Don't want to keep a strong ref to 'main' in memoization cache key + jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic( + f, self.main, in_avals, keep_inputs=keep_inputs) + else: + jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic_memoized( + f, self.main, tuple(in_avals), tuple(keep_inputs)).val if params.get('inline', False): return core.eval_jaxpr(jaxpr, consts, *dim_tracers, *tracers) source_info = source_info_util.current() @@ -1603,6 +1608,17 @@ def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, main: core.MainTrace, del fun, main, trace, frame, in_tracers, out_tracers, ans return jaxpr, out_avals, consts +@lu.cache +def trace_to_subjaxpr_dynamic_memoized( + fun: lu.WrappedFun, main: core.MainTrace, in_avals, keep_inputs): + tup = trace_to_subjaxpr_dynamic(fun, main, in_avals, keep_inputs=keep_inputs) + return WrapperForWeakRef(tup) + +class WrapperForWeakRef: + val: Any + def __init__(self, val): + self.val = val + @contextlib.contextmanager def extend_jaxpr_stack(main, frame): main.jaxpr_stack = main.jaxpr_stack + (frame,) diff --git a/tests/api_test.py b/tests/api_test.py index 59fa0ceabd10..9cdec48e0941 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -1260,7 +1260,6 @@ def f(x, u): self.assertEqual(fwd(rev(f, 0), 1)(x, u).shape, (5, 2)) self.assertEqual(fwd(fwd(f, 0), 1)(x, u).shape, (5, 2)) - def test_large_device_constant(self): ans = jit(lambda x: 2 * x)(jnp.ones(int(2e6))) # doesn't crash self.assertAllClose(ans, np.ones(int(2e6)) * 2., check_dtypes=False) @@ -3322,6 +3321,49 @@ def test_jnp_array_doesnt_device_put(self): api.make_jaxpr(lambda: jnp.array(3))() self.assertEqual(count[0], 0) + def test_subcall_trace_caching(self): + should_be_tracing_f = False + + @api.jit + def f(x): + self.assertTrue(should_be_tracing_f) + return x ** 2 + + @api.jit + def g(x): + nonlocal should_be_tracing_f + self.assertTrue(should_be_tracing_g) + should_be_tracing_f = True + y = f(x) + should_be_tracing_f = False + z = f(x + 1) + return y + z + + should_be_tracing_g = True + out = g(2) + self.assertEqual(out, 13) + + should_be_tracing_g = False + out = g(3) + self.assertEqual(out, 25) + + def test_subcall_jaxpr_id(self): + @api.jit + def f(x): + return x ** 2 + + def g(x): + y = f(x) + z = f(x + 1) + return y + z + + jaxpr = api.make_jaxpr(g)(2) + self.assertIn('call_jaxpr', jaxpr.eqns[0].params) + self.assertIn('call_jaxpr', jaxpr.eqns[2].params) + subjaxpr1 = jaxpr.eqns[0].params['call_jaxpr'] + subjaxpr2 = jaxpr.eqns[2].params['call_jaxpr'] + self.assertIs(subjaxpr1, subjaxpr2) + class RematTest(jtu.JaxTestCase): @@ -3776,7 +3818,7 @@ def g(): return seq[0] remat(g)() - remat(g)() + remat(lambda: g())() # lambda defeats caching with self.assertRaisesRegex(UnexpectedTracerError, "global state"): api.jit(f)() diff --git a/tests/pmap_test.py b/tests/pmap_test.py index 37a7c8ee5b49..d035099fcdb3 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -1877,7 +1877,7 @@ def f(x): with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841 _, f_bwd = jax.vjp(f, x) _ = f_bwd(x) - self.assertEqual(count[0], 2) # one for fwd, one for bwd + self.assertEqual(count[0], 2) # once for fwd, once for bwd with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841 _, f_bwd2 = jax.vjp(f, x)