From 0e6fea4d19d8318f840ab064accfd49e35461dac Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Thu, 2 Mar 2023 23:46:24 +0000 Subject: [PATCH] in `closure_convert`, hoist all tracers involved in staging We need to hoist all `DynamicJaxprTracer`s in `jax.closure_convert`, not only those that are float-dtyped. Leaving any `DynamicJaxprTracer` in the closure of a `custom_{jvp,vjp}`'ed function might result in a tracer leak, as staging to jaxpr could already be complete by the time the closure is consumed. Co-authored-by: Matthew Johnson --- jax/_src/custom_derivatives.py | 15 +++++------- tests/api_test.py | 44 ++++++++++++++++++++++++++++++++-- 2 files changed, 48 insertions(+), 11 deletions(-) diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 909d79e7e6f7..9850772b26e2 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -21,7 +21,6 @@ from jax._src import core from jax._src import custom_api_util from jax._src.custom_transpose import custom_transpose -from jax._src import dtypes from jax._src import effects from jax._src import linear_util as lu from jax._src import traceback_util @@ -1072,7 +1071,7 @@ def rev(objective_fn, res, g): else: return _closure_convert_for_avals(fun, in_tree, in_avals) -def _maybe_perturbed(x: Any) -> bool: +def hoistworthy(x: Any) -> bool: # False if x can't represent an AD-perturbed value (i.e. a value # with a nontrivial tangent attached), up to heuristics, and True otherwise. # See https://github.com/google/jax/issues/6415 for motivation. @@ -1081,14 +1080,12 @@ def _maybe_perturbed(x: Any) -> bool: # If x is not a Tracer, it can't be perturbed. return False elif isinstance(x, pe.DynamicJaxprTracer): - # If x is a DynamicJaxprTracer then we're staging out; differentiation could - # happen later, but some types always have trivial tangents. - vspace = x.aval.at_least_vspace() - return not (vspace is core.abstract_token or - getattr(vspace, 'dtype', None) == dtypes.float0) + # If x is a DynamicJaxprTracer then we're actively staging out. We can't + # keep the tracers involved in staging in closure. That'd be a tracer leak! + return True elif not isinstance(x, ad.JVPTracer): # If x is not a JVPTracer, recursively check its contents. - return any(_maybe_perturbed(attr) for name, attr in x._contents()) + return any(hoistworthy(attr) for name, attr in x._contents()) else: return True # We can't be sure! @@ -1098,7 +1095,7 @@ def _closure_convert_for_avals(fun, in_tree, in_avals): jaxpr, out_pvals, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals) out_tree = out_tree() - (closure_consts, hoisted_consts), merge = partition_list(_maybe_perturbed, consts) + (closure_consts, hoisted_consts), merge = partition_list(hoistworthy, consts) num_consts = len(hoisted_consts) def converted_fun(*args_hconsts): diff --git a/tests/api_test.py b/tests/api_test.py index f788bb097216..1c8ca68af82b 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -7337,14 +7337,14 @@ def f_jvp(primals, tangents): shape = grad(lambda x: jnp.sum(f(x)))(jnp.array(1.)).shape self.assertEqual(shape, ()) - def test_maybe_perturbed_internal_helper_function(self): + def test_hoisting_internal_helper_function(self): # This is a unit test for an internal API. We include it so as not to # regress https://github.com/google/jax/issues/9567. For an explanation of # this helper function, see https://github.com/google/jax/issues/6415. def f(x): def g(y, _): z = y * x - self.assertTrue(custom_derivatives._maybe_perturbed(z)) + self.assertTrue(custom_derivatives.hoistworthy(z)) return y, None g(1, None) return lax.scan(g, 1, xs=None, length=1)[0] @@ -8456,6 +8456,46 @@ def closure(x): self.assertAllClose(g_c, 42. * c, check_dtypes=False) self.assertAllClose(g_x, 17. * x, check_dtypes=False) + def test_closure_convert_mixed_consts_jitted(self): + def cos_after(fn, x): + converted_fn, aux_args = jax.closure_convert(fn, x) + return _cos_after(converted_fn, x, *aux_args) + + @partial(jax.custom_vjp, nondiff_argnums=(0,)) + def _cos_after(fn, x, *args): + return jnp.cos(fn(x, *args)) + + def fwd(fn, x, *args): + y = _cos_after(fn, x, *args) + return y, (x, args) + + def rev(fn, res, g): + x, args = res + x_bar = 17. * x + args_bars = [42. * a.astype(x.dtype) for a in args] + return (x_bar, *args_bars) + + _cos_after.defvjp(fwd, rev) + + def dist(c, s, i, x): + return jnp.sum(i.astype(x.dtype) * s * (x - c) ** 2.) + + @jax.jit + def solve(c, s, i, x): + def closure(x): + return dist(c, s, i, x) + return cos_after(closure, x) + + c = 2. * jnp.ones(2) + s = 3. * jnp.ones(2) + i = jnp.ones(2, 'int32') + x = jnp.ones(2) + expected = jnp.cos(dist(c, s, i, x)) + self.assertAllClose(solve(c, s, i, x), expected, check_dtypes=False) + g_c, g_x = api.grad(solve, argnums=(0, 3))(c, s, i, x) + self.assertAllClose(g_c, 42. * c, check_dtypes=False) + self.assertAllClose(g_x, 17. * x, check_dtypes=False) + def test_float0_cotangents_automatically_handled(self): @jax.custom_vjp def f(x, y):