Skip to content

Commit

Permalink
in closure_convert, hoist all tracers involved in staging
Browse files: browse the repository at this point in the history
We need to hoist all `DynamicJaxprTracer`s in `jax.closure_convert`,
not only those that are float-dtyped. Leaving any `DynamicJaxprTracer`
in the closure of a `custom_{jvp,vjp}`'ed function might result in a
tracer leak, as staging to jaxpr could already be complete by the time
the closure is consumed.

Co-authored-by: Matthew Johnson <[email protected]>
  • Loading branch information
froystig and mattjj committed May 15, 2023
1 parent 7aefc9a commit 0e6fea4
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 11 deletions.
15 changes: 6 additions & 9 deletions jax/_src/custom_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from jax._src import core
from jax._src import custom_api_util
from jax._src.custom_transpose import custom_transpose
from jax._src import dtypes
from jax._src import effects
from jax._src import linear_util as lu
from jax._src import traceback_util
Expand Down Expand Up @@ -1072,7 +1071,7 @@ def rev(objective_fn, res, g):
else:
return _closure_convert_for_avals(fun, in_tree, in_avals)

def _maybe_perturbed(x: Any) -> bool:
def hoistworthy(x: Any) -> bool:
# True if x must be hoisted out of the closure: either x may represent an
# AD-perturbed value (i.e. a value with a nontrivial tangent attached), up to
# heuristics, or x is a tracer involved in staging. False otherwise.
# See https://github.com/google/jax/issues/6415 for motivation.
Expand All @@ -1081,14 +1080,12 @@ def _maybe_perturbed(x: Any) -> bool:
# If x is not a Tracer, it can't be perturbed.
return False
elif isinstance(x, pe.DynamicJaxprTracer):
# If x is a DynamicJaxprTracer then we're staging out; differentiation could
# happen later, but some types always have trivial tangents.
vspace = x.aval.at_least_vspace()
return not (vspace is core.abstract_token or
getattr(vspace, 'dtype', None) == dtypes.float0)
# If x is a DynamicJaxprTracer then we're actively staging out. We can't
# keep the tracers involved in staging in closure. That'd be a tracer leak!
return True
elif not isinstance(x, ad.JVPTracer):
# If x is not a JVPTracer, recursively check its contents.
return any(_maybe_perturbed(attr) for name, attr in x._contents())
return any(hoistworthy(attr) for name, attr in x._contents())
else:
return True # We can't be sure!

Expand All @@ -1098,7 +1095,7 @@ def _closure_convert_for_avals(fun, in_tree, in_avals):
jaxpr, out_pvals, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
out_tree = out_tree()

(closure_consts, hoisted_consts), merge = partition_list(_maybe_perturbed, consts)
(closure_consts, hoisted_consts), merge = partition_list(hoistworthy, consts)
num_consts = len(hoisted_consts)

def converted_fun(*args_hconsts):
Expand Down
44 changes: 42 additions & 2 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7337,14 +7337,14 @@ def f_jvp(primals, tangents):
shape = grad(lambda x: jnp.sum(f(x)))(jnp.array(1.)).shape
self.assertEqual(shape, ())

def test_maybe_perturbed_internal_helper_function(self):
def test_hoisting_internal_helper_function(self):
# This is a unit test for an internal API. We include it so as not to
# regress https://github.com/google/jax/issues/9567. For an explanation of
# this helper function, see https://github.com/google/jax/issues/6415.
def f(x):
def g(y, _):
z = y * x
self.assertTrue(custom_derivatives._maybe_perturbed(z))
self.assertTrue(custom_derivatives.hoistworthy(z))
return y, None
g(1, None)
return lax.scan(g, 1, xs=None, length=1)[0]
Expand Down Expand Up @@ -8456,6 +8456,46 @@ def closure(x):
self.assertAllClose(g_c, 42. * c, check_dtypes=False)
self.assertAllClose(g_x, 17. * x, check_dtypes=False)

def test_closure_convert_mixed_consts_jitted(self):
# Exercises jax.closure_convert inside a jitted function whose closure
# captures a mix of float-valued (c, s) and integer-valued (i) arrays.
def cos_after(fn, x):
# Convert fn so that values it closes over are returned as aux_args, then
# pass them explicitly to the custom_vjp function.
converted_fn, aux_args = jax.closure_convert(fn, x)
return _cos_after(converted_fn, x, *aux_args)

@partial(jax.custom_vjp, nondiff_argnums=(0,))
def _cos_after(fn, x, *args):
return jnp.cos(fn(x, *args))

def fwd(fn, x, *args):
# Residuals are the primal input and the hoisted closure values.
y = _cos_after(fn, x, *args)
return y, (x, args)

def rev(fn, res, g):
# The incoming cotangent g is not used: the rule returns 17 * x for x and
# 42 * a for each hoisted arg (cast to x's dtype). The assertions at the end
# of this test compare the computed gradients against these same values.
x, args = res
x_bar = 17. * x
args_bars = [42. * a.astype(x.dtype) for a in args]
return (x_bar, *args_bars)

_cos_after.defvjp(fwd, rev)

def dist(c, s, i, x):
# i is cast to x's dtype before being used as a multiplicative weight.
return jnp.sum(i.astype(x.dtype) * s * (x - c) ** 2.)

@jax.jit
def solve(c, s, i, x):
# closure captures c, s and i from the enclosing jitted function.
def closure(x):
return dist(c, s, i, x)
return cos_after(closure, x)

c = 2. * jnp.ones(2)
s = 3. * jnp.ones(2)
i = jnp.ones(2, 'int32')
x = jnp.ones(2)
# Forward value must match the direct computation.
expected = jnp.cos(dist(c, s, i, x))
self.assertAllClose(solve(c, s, i, x), expected, check_dtypes=False)
# Gradients w.r.t. c (argnum 0) and x (argnum 3) must match the values
# produced by rev above.
g_c, g_x = api.grad(solve, argnums=(0, 3))(c, s, i, x)
self.assertAllClose(g_c, 42. * c, check_dtypes=False)
self.assertAllClose(g_x, 17. * x, check_dtypes=False)

def test_float0_cotangents_automatically_handled(self):
@jax.custom_vjp
def f(x, y):
Expand Down

0 comments on commit 0e6fea4

Please sign in to comment.