diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 13b9caf3d749..5c97a3dcdff8 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -16,7 +16,7 @@ from collections.abc import Callable, Sequence import dataclasses -from functools import update_wrapper, reduce, partial +from functools import update_wrapper, reduce, partial, wraps import inspect from typing import Any, Generic, TypeVar @@ -497,6 +497,7 @@ def __init__(self, self.fwd: Callable[..., tuple[ReturnValue, Any]] | None = None self.bwd: Callable[..., tuple[Any, ...]] | None = None self.symbolic_zeros = False + self.optimize_remat = False __getattr__ = custom_api_util.forward_attr @@ -504,6 +505,7 @@ def defvjp(self, fwd: Callable[..., tuple[ReturnValue, Any]], bwd: Callable[..., tuple[Any, ...]], symbolic_zeros: bool = False, + optimize_remat: bool = False, ) -> None: """Define a custom VJP rule for the function represented by this instance. @@ -560,6 +562,10 @@ def defvjp(self, objects that are given as input leaves to the ``fwd`` rule. Default ``False``. + optimize_remat: boolean, an experimental flag to enable an automatic + optimization when this function is used under :func:`jax.remat`. This + will be most useful when the ``fwd`` rule is an opaque call such as a + Pallas kernel or a custom call. Default ``False``. Returns: None. @@ -582,6 +588,10 @@ def f_bwd(res, g): self.fwd = fwd self.bwd = bwd self.symbolic_zeros = symbolic_zeros + self.optimize_remat = optimize_remat + if self.symbolic_zeros and self.optimize_remat: + raise NotImplementedError( + "remat optimization for custom_vjp does not support symbolic zeros") @traceback_util.api_boundary def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable=invalid-annotation @@ -591,6 +601,12 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable raise AttributeError(msg) fwd_name = getattr(self.fwd, '__name__', str(self.fwd)) args = _resolve_kwargs(self.fun, args, kwargs) + if self.optimize_remat: + fwd = optimize_remat_of_custom_vjp_fwd( + self.fun, self.fwd, nondiff_argnums=self.nondiff_argnums, + symbolic_zeros=self.symbolic_zeros) + else: + fwd = self.fwd if config.enable_custom_vjp_by_custom_transpose.value: if self.nondiff_argnums: raise NotImplementedError( @@ -604,16 +620,16 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), dyn_argnums, args, require_static_args_hashable=False) static_args = [args[i] for i in self.nondiff_argnums] - fwd, _ = argnums_partial(lu.wrap_init(self.fwd), dyn_argnums, args, + fwd_, _ = argnums_partial(lu.wrap_init(fwd), dyn_argnums, args, require_static_args_hashable=False) bwd = _add_args(lu.wrap_init(self.bwd), static_args) else: f_, dyn_args = lu.wrap_init(self.fun), args - fwd, bwd = lu.wrap_init(self.fwd), lu.wrap_init(self.bwd) + fwd_, bwd = lu.wrap_init(fwd), lu.wrap_init(self.bwd) args_flat, in_tree = tree_flatten(dyn_args) in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat] flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree) - flat_fwd, out_trees = _flatten_fwd(fwd, self.symbolic_zeros, primal_name, + flat_fwd, out_trees = _flatten_fwd(fwd_, self.symbolic_zeros, primal_name, fwd_name, in_tree, out_type) flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees).call_wrapped out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd, @@ -1407,3 +1423,226 @@ def jvp(primals, tangents): # TODO(mattjj): remove these stubs, which exist to avoid breaking internal users custom_jvp_call_jaxpr_p = core.Primitive("custom_jvp_call_jaxpr") + + +# The following is a helper for optimizing the behavior of custom_vjp when used +# under remat. This is really only useful when the `fwd` function to custom_vjp +# executes a black box kernel. Otherwise, DCE will perform this optimization +# automatically. +# +# TODO(dfm): Eventually this should probably be the default behavior for +# custom_vjp, if we can make it so that it is a no-op for most cases. Right now, +# it is written in "initial-style" so it doesn't support eager mode. This was +# a reasonable compromise when written because it made the implementation +# simpler, but it would be worth revisiting this. +def optimize_remat_of_custom_vjp_fwd( + fun: Callable[..., ReturnValue], + fwd: Callable[..., tuple[ReturnValue, Any]], + nondiff_argnums: tuple[int, ...] = (), + symbolic_zeros: bool = False, +) -> Callable[..., tuple[ReturnValue, Any]]: + if symbolic_zeros: + # TODO(dfm): This probably shouldn't be too hard to support. + raise NotImplementedError( + "remat optimization for custom_vjp does not support symbolic zeros") + + @wraps(fwd) + def wrapped_fwd(*args, **kwargs) -> tuple[ReturnValue, Any]: + # TODO(dfm): This initial logic is duplicated from custom_vjp.__call__ + # above and it would be good to consolidate it. + primal_name = getattr(fun, "__name__", str(fun)) + fwd_name = getattr(fwd, "__name__", str(fwd)) + args = _resolve_kwargs(fwd, args, kwargs) + if nondiff_argnums: + for i in nondiff_argnums: _check_for_tracers(args[i]) + nondiff_argnums_ = set(nondiff_argnums) + dyn_argnums = [i for i in range(len(args)) if i not in nondiff_argnums_] + f_, dyn_args = argnums_partial(lu.wrap_init(fun), dyn_argnums, + args, require_static_args_hashable=False) + fwd_, _ = argnums_partial(lu.wrap_init(fwd), dyn_argnums, args, + require_static_args_hashable=False) + else: + f_, dyn_args = lu.wrap_init(fun), args + fwd_ = lu.wrap_init(fwd) + args_flat, in_tree = tree_flatten(dyn_args) + flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree) + flat_fwd, out_trees = _flatten_fwd(fwd_, False, primal_name, fwd_name, + in_tree, out_type) + + in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat] + fwd_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fwd, in_avals) + fwd_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(fwd_jaxpr)) + prim_tree, res_tree = out_trees() + num_res = res_tree.num_leaves + + if fwd_jaxpr.effects: + raise NotImplementedError( + "remat optimization for custom_vjp does not support forward " + f"functions with side effects, but {fwd_name} has the following " + f"effects: {fwd_jaxpr.effects}") + + @pe._memoize + def fun_jaxpr_thunk(): + jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals) + return jaxpr, consts + + out_flat = remat_opt_p.bind(*consts, *args_flat, + num_consts=len(consts), + num_res=num_res, + fwd_jaxpr=fwd_jaxpr, + fun_jaxpr_thunk=fun_jaxpr_thunk) + res, out_flat = split_list(out_flat, [num_res]) + out_tree = treedef_tuple((prim_tree, res_tree)) + return tree_unflatten(out_tree, (*out_flat, *res)) + + return wrapped_fwd + +def _remat_opt_impl( + *args, + num_consts: int, + num_res: int, + fwd_jaxpr: core.ClosedJaxpr, + fun_jaxpr_thunk: Callable[[], core.ClosedJaxpr], +): + del num_consts, num_res, fun_jaxpr_thunk # unused + return core.jaxpr_as_fun(fwd_jaxpr)(*args) + +def _remat_opt_abstract_eval(*args, fwd_jaxpr: core.ClosedJaxpr, **_): + del args + return fwd_jaxpr.out_avals, fwd_jaxpr.effects + +def _remat_opt_vmap( + spmd_axis_name, axis_size, axis_name, main_type, args, in_dims, + *, + num_consts: int, + num_res: int, + fwd_jaxpr: core.ClosedJaxpr, + fun_jaxpr_thunk: Callable[[], core.ClosedJaxpr], +): + args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0 + else x for x, d in zip(args, in_dims)] + + in_batched = [d is not not_mapped for d in in_dims] + batched_fwd_jaxpr, out_batched = batching.batch_jaxpr( + fwd_jaxpr, axis_size, in_batched, False, + axis_name, spmd_axis_name, main_type) + out_dims = [0 if b else not_mapped for b in out_batched] + + _, prim_batched = split_list(in_batched, [num_consts]) + + @pe._memoize + def batched_fun_jaxpr_thunk(): + fun_jaxpr = core.ClosedJaxpr(*fun_jaxpr_thunk()) + batched_fun_jaxpr, out_batched = batching.batch_jaxpr( + fun_jaxpr, axis_size, prim_batched, False, axis_name, spmd_axis_name, + main_type) + return batched_fun_jaxpr.jaxpr, batched_fun_jaxpr.consts + + batched_outs = remat_opt_p.bind(*args, num_consts=num_consts, + num_res=num_res, + fwd_jaxpr=batched_fwd_jaxpr, + fun_jaxpr_thunk=batched_fun_jaxpr_thunk) + + return batched_outs, out_dims + +def _remat_opt_jvp( + primals, + tangents, + *, + num_consts: int, + num_res: int, + fwd_jaxpr: core.ClosedJaxpr, + fun_jaxpr_thunk: Callable[[], core.ClosedJaxpr], +): + consts, primals = split_list(primals, [num_consts]) + consts_dot, tangents = split_list(tangents, [num_consts]) + # Tangents must be instantated in case we end up DCEing later. + tangents = map(ad.instantiate_zeros, tangents) + consts_nz = [not isinstance(t, Zero) for t in consts_dot] + consts_dot = [c for nz, c in zip(consts_nz, consts_dot) if nz] + in_nz = consts_nz + [True] * len(tangents) + fwd_jaxpr_jvp_, out_nz = ad.jvp_jaxpr(fwd_jaxpr, in_nz, True) + num_out = len(out_nz) - num_res + fwd_jaxpr_jvp_ = ad.rearrange_binders( + fwd_jaxpr_jvp_, [num_consts, len(primals)], + [len(consts_dot), len(tangents)], [num_res, num_out], [num_res, num_out]) + fwd_jaxpr_jvp = pe.close_jaxpr(pe.convert_constvars_jaxpr(fwd_jaxpr_jvp_.jaxpr)) + + @pe._memoize + def fun_jvp_jaxpr_thunk(): + fun_jaxpr = core.ClosedJaxpr(*fun_jaxpr_thunk()) + in_nz = [True] * len(primals) + fun_jvp_jaxpr, _ = ad.jvp_jaxpr(fun_jaxpr, in_nz, True) + return fun_jvp_jaxpr.jaxpr, fun_jvp_jaxpr.consts + + new_num_consts = len(fwd_jaxpr_jvp_.consts) + num_consts + len(consts_dot) + outs = remat_opt_p.bind(*fwd_jaxpr_jvp_.consts, *consts, *consts_dot, + *primals, *tangents, num_consts=new_num_consts, + num_res=2 * num_res, fwd_jaxpr=fwd_jaxpr_jvp, + fun_jaxpr_thunk=fun_jvp_jaxpr_thunk) + res, res_dot, outs, outs_dot = split_list(outs, [num_res, num_res, num_out]) + return (*res, *outs), (*res_dot, *outs_dot) + +def _remat_opt_transpose( + cts, *args, + num_consts: int, + num_res: int, + fwd_jaxpr: core.ClosedJaxpr, + fun_jaxpr_thunk: Callable[[], core.ClosedJaxpr], +): + # TODO(dfm): It shouldn't be too hard to implement this as needed in the + # future. + raise NotImplementedError( + "remat optimization for custom_vjp does not support higher-order AD") + +def _remat_opt_dce(used_outs: list[bool], eqn: core.JaxprEqn): + used_res, used_prims = split_list(used_outs, [eqn.params["num_res"]]) + outvars = [v for used, v in zip(used_outs, eqn.outvars) if used] + if any(used_res): + # If any of the residuals are used, we still need to run fwd at this point, + # but we may end up DCEing again in the future, so we must instantiate all + # the input primals. + instantiate = [False] * eqn.params["num_consts"] + instantiate += [True] * (len(eqn.invars) - eqn.params["num_consts"]) + new_jaxpr, used_ins = pe.dce_jaxpr(eqn.params["fwd_jaxpr"].jaxpr, used_outs, + instantiate=instantiate) + closed_jaxpr = pe.close_jaxpr(new_jaxpr) + invars = [v for used, v in zip(used_ins, eqn.invars) if used] + new_params = dict(eqn.params) + new_num_consts = sum(split_list(used_ins, [eqn.params["num_consts"]])[0]) + new_params["num_consts"] = new_num_consts + new_params["fwd_jaxpr"] = closed_jaxpr + new_params["num_res"] = sum(used_res) + new_eqn = pe.new_jaxpr_eqn( + invars, outvars, remat_opt_p, new_params, closed_jaxpr.effects, + eqn.source_info, eqn.ctx) + return used_ins, new_eqn + else: + # If none of the residuals are used, we run the primal computation instead. + # At this point we drop this custom DCE behavior, but since the primal might + # have different consts than fwd, we build a new JaxprEqn with a closed_call + # primitive. + fun_jaxpr, consts = eqn.params["fun_jaxpr_thunk"]() + new_jaxpr, used_consts, used_ins = pe.dce_jaxpr_consts(fun_jaxpr, used_prims) + consts = [c for used, c in zip(used_consts, consts) if used] + closed_jaxpr = core.ClosedJaxpr(new_jaxpr, consts) + _, invars = split_list(eqn.invars, [eqn.params["num_consts"]]) + invars = [v for used, v in zip(used_ins, invars) if used] + new_eqn = pe.new_jaxpr_eqn( + invars, outvars, core.closed_call_p, dict(call_jaxpr=closed_jaxpr), + closed_jaxpr.effects, eqn.source_info, eqn.ctx) + used_ins = [False] * eqn.params["num_consts"] + used_ins + return used_ins, new_eqn + +remat_opt_p = core.Primitive("remat_opt") +remat_opt_p.multiple_results = True +remat_opt_p.def_impl(_remat_opt_impl) +remat_opt_p.def_effectful_abstract_eval(_remat_opt_abstract_eval) +xla.register_initial_style_primitive(remat_opt_p) +mlir.register_lowering(remat_opt_p, mlir.lower_fun( + _remat_opt_impl, multiple_results=True)) +batching.spmd_axis_primitive_batchers[remat_opt_p] = _remat_opt_vmap +batching.axis_primitive_batchers[remat_opt_p] = partial(_remat_opt_vmap, None) +ad.primitive_jvps[remat_opt_p] = _remat_opt_jvp +ad.primitive_transposes[remat_opt_p] = _remat_opt_transpose +pe.dce_rules[remat_opt_p] = _remat_opt_dce diff --git a/jax/custom_derivatives.py b/jax/custom_derivatives.py index f6b8ff9e94d6..96dc8898fd8e 100644 --- a/jax/custom_derivatives.py +++ b/jax/custom_derivatives.py @@ -30,6 +30,7 @@ custom_vjp_primal_tree_values as custom_vjp_primal_tree_values, CustomVJPPrimal as CustomVJPPrimal, linear_call as linear_call, + remat_opt_p as remat_opt_p, ) from jax._src.ad_util import ( diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 87b36f03ca44..545945c91ffd 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -3448,6 +3448,17 @@ def _custom_lin(*args: TfVal, **_) -> Sequence[TfVal]: tf_impl[ad.custom_lin_p] = _custom_lin +def _remat_opt(*args: TfVal, num_consts: int, num_res: int, + fwd_jaxpr: core.ClosedJaxpr, + fun_jaxpr_thunk: Callable) -> Sequence[TfVal]: + del num_consts, num_res, fun_jaxpr_thunk + return _interpret_jaxpr(fwd_jaxpr, *args, extra_name_stack="remat_opt", + fresh_constant_cache=False) + + +tf_impl[custom_derivatives.remat_opt_p] = _remat_opt + + PartitionsOrReplicated = Union[tuple[int, ...], None] def split_to_logical_devices(tensor: TfVal, diff --git a/tests/api_test.py b/tests/api_test.py index 1224ade3358b..f5f8075afd90 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -9653,6 +9653,83 @@ def f_bwd(xy, g): jax.grad(f)(1., 2.) # don't crash + def test_optimize_remat(self): + def fun(x): + # This array is included to make sure that we handle consts appropriately + return np.array([1.0])*x + + def fwd(x): + return np.array([2.0])*x*x/np.array([1.0]), (x,) + + fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(fun, fwd) + x = jnp.linspace(0, 5.0, 10) + self.assertAllClose(jax.jit(fwd)(x)[0], 2*x*x) # Shouldn't hit custom DCE + self.assertAllClose(jax.jit(lambda x: fwd(x)[0])(x), x) # Should be DCEed + + def test_optimize_remat_vmap(self): + def fun(x): + return (np.array([1.0])*x)[0] + def fwd(x): + return (np.array([2.0])*x*x/np.array([1.0]))[0], (x,) + fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(fun, fwd) + x = jnp.linspace(0, 5.0, 10) + self.assertAllClose(jax.jit(jax.vmap(fwd))(x)[0], 2*x*x) + self.assertAllClose(jax.jit(lambda x: jax.vmap(fwd)(x)[0])(x), x) + + def test_optimize_remat_cond(self): + def fun(x): + return x + def fwd(x): + return x*x, (x,) + fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(fun, fwd) + + def g(x): + return jax.lax.cond(True, fwd, lambda x: (2.0 * x, (x,)), x) + x = jnp.linspace(0, 5.0, 10) + self.assertAllClose(jax.jit(g)(x)[0], x*x) + self.assertAllClose(jax.jit(lambda x: g(x)[0])(x), x) + + def test_optimize_remat_jvp(self): + def fun(x): + return x**2 + def fwd_(x): + return x*x, (x,) + fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(fun, fwd_) + calc = jax.jvp(fwd, (3.2,), (1.0,)) + expected = jax.jvp(fwd_, (3.2,), (1.0,)) + self.assertAllClose(calc, expected) + + @jax.jit + def g(x, t): + (y, r), (y_dot, r_dot) = jax.jvp(fwd, (x,), (t,)) + return y, y_dot + calc = g(3.2, 1.0) + expected = jax.jvp(fun, (3.2,), (1.0,)) + self.assertAllClose(calc, expected) + + def test_optimize_remat_gh21303(self): + @jax.custom_vjp + def f(x): + return jnp.tan(x) + + def f_fwd(x): + return jnp.sin(x), (x,) + + def f_bwd(res, g): + x, = res + cos_x = jnp.cos(x) + return (cos_x * g,) + + f.defvjp(f_fwd, f_bwd, optimize_remat=True) + + def temp(x): + out = jax.remat(f)(x) + out = out ** 2 + return out + + v, g = jax.value_and_grad(temp)(3.2) + self.assertAllClose(v, jnp.tan(3.2)**2) + def transpose_unary(f, x_example): def transposed(y):