Skip to content

Commit

Permalink
cache tracing of (sub)calls when forming a jaxpr
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Jan 13, 2022
1 parent 2d4e797 commit 8b5a9f5
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 7 deletions.
2 changes: 0 additions & 2 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6775,8 +6775,6 @@ def _multi_slice(arr,
start_indices: Tuple[Tuple[int, ...]],
limit_indices: Tuple[Tuple[int, ...]],
removed_dims: Tuple[Tuple[int, ...]]):
print(core.thread_local_state.trace_state.axis_env)
breakpoint()
"""Extracts multiple slices from `arr`.
This is used to shard DeviceArray arguments to pmap. It's implemented as a
Expand Down
20 changes: 18 additions & 2 deletions jax/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1361,8 +1361,13 @@ def process_call(self, call_primitive, f, tracers, params):
in_avals = _tracers_to_avals({}, dim_tracers + tracers)
keep_inputs = [False] * len(dim_tracers) + [True] * len(tracers)
with core.new_sublevel():
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
f, self.main, in_avals, keep_inputs=keep_inputs)
if config.jax_check_tracer_leaks:
# Don't want to keep a strong ref to 'main' in memoization cache key
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
f, self.main, in_avals, keep_inputs=keep_inputs)
else:
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic_memoized(
f, self.main, tuple(in_avals), tuple(keep_inputs)).val
if params.get('inline', False):
return core.eval_jaxpr(jaxpr, consts, *dim_tracers, *tracers)
source_info = source_info_util.current()
Expand Down Expand Up @@ -1603,6 +1608,17 @@ def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, main: core.MainTrace,
del fun, main, trace, frame, in_tracers, out_tracers, ans
return jaxpr, out_avals, consts

@lu.cache
def trace_to_subjaxpr_dynamic_memoized(
fun: lu.WrappedFun, main: core.MainTrace, in_avals, keep_inputs):
tup = trace_to_subjaxpr_dynamic(fun, main, in_avals, keep_inputs=keep_inputs)
return WrapperForWeakRef(tup)

class WrapperForWeakRef:
val: Any
def __init__(self, val):
self.val = val

@contextlib.contextmanager
def extend_jaxpr_stack(main, frame):
main.jaxpr_stack = main.jaxpr_stack + (frame,)
Expand Down
46 changes: 44 additions & 2 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1260,7 +1260,6 @@ def f(x, u):
self.assertEqual(fwd(rev(f, 0), 1)(x, u).shape, (5, 2))
self.assertEqual(fwd(fwd(f, 0), 1)(x, u).shape, (5, 2))


def test_large_device_constant(self):
ans = jit(lambda x: 2 * x)(jnp.ones(int(2e6))) # doesn't crash
self.assertAllClose(ans, np.ones(int(2e6)) * 2., check_dtypes=False)
Expand Down Expand Up @@ -3322,6 +3321,49 @@ def test_jnp_array_doesnt_device_put(self):
api.make_jaxpr(lambda: jnp.array(3))()
self.assertEqual(count[0], 0)

def test_subcall_trace_caching(self):
should_be_tracing_f = False

@api.jit
def f(x):
self.assertTrue(should_be_tracing_f)
return x ** 2

@api.jit
def g(x):
nonlocal should_be_tracing_f
self.assertTrue(should_be_tracing_g)
should_be_tracing_f = True
y = f(x)
should_be_tracing_f = False
z = f(x + 1)
return y + z

should_be_tracing_g = True
out = g(2)
self.assertEqual(out, 13)

should_be_tracing_g = False
out = g(3)
self.assertEqual(out, 25)

def test_subcall_jaxpr_id(self):
@api.jit
def f(x):
return x ** 2

def g(x):
y = f(x)
z = f(x + 1)
return y + z

jaxpr = api.make_jaxpr(g)(2)
self.assertIn('call_jaxpr', jaxpr.eqns[0].params)
self.assertIn('call_jaxpr', jaxpr.eqns[2].params)
subjaxpr1 = jaxpr.eqns[0].params['call_jaxpr']
subjaxpr2 = jaxpr.eqns[2].params['call_jaxpr']
self.assertIs(subjaxpr1, subjaxpr2)


class RematTest(jtu.JaxTestCase):

Expand Down Expand Up @@ -3776,7 +3818,7 @@ def g():
return seq[0]

remat(g)()
remat(g)()
remat(lambda: g())() # lambda defeats caching

with self.assertRaisesRegex(UnexpectedTracerError, "global state"):
api.jit(f)()
Expand Down
2 changes: 1 addition & 1 deletion tests/pmap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1877,7 +1877,7 @@ def f(x):
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
_, f_bwd = jax.vjp(f, x)
_ = f_bwd(x)
self.assertEqual(count[0], 2) # one for fwd, one for bwd
self.assertEqual(count[0], 2) # once for fwd, once for bwd

with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
_, f_bwd2 = jax.vjp(f, x)
Expand Down

0 comments on commit 8b5a9f5

Please sign in to comment.