Skip to content

Commit

Permalink
Add optional automatic remat optimization to custom_vjp.
Browse files Browse the repository at this point in the history
As reported in jax-ml#21303, using `remat`
with `custom_vjp` can produce inefficient results. The high level
summary is that computing the grad of such a function results in the
`fwd` function of the `custom_vjp` being evaluated twice, even though
the first time the residuals are not actually used. In many cases this
isn't a problem because DCE will clean up the unnecessary computations.
But, when the fwd function requires an opaque call (e.g. pallas_call or
ffi_call), this no longer saves the day.

In this PR, I have added a parameter to `custom_vjp` called
`optimize_remat` (open for discussion!), which can be used to opt-in to
automatic optimization of this operation. Setting this flag to true
results in the `fwd` function being wrapped in a new custom primitive
which will DCE into a call to the primal function whenever the residuals
are unused.

This can be used to fix jax-ml#21303, and
I think it would make sense to eventually make this behavior the
default, but this implementation comes with a few caveats:

1. This feature is currently implemented in "initial style", which means
   that the `fwd` function is traced to a jaxpr when it is initially
   called. This means that when `optimize_remat=True`, the `custom_vjp`
   function doesn't support data dependent conditionals within `fwd`.
   This isn't a fundamental limitation of the method, but this
   implementation is much simpler so it seemed like a good place to
   start, and much of the complexity of the "final style" version of
   this logic should be simplified by work that @dougalm is doing.
   Furthermore, for the immediate use case of opaque calls, initial
   style is not a serious limitation.
2. When `optimize_remat=True`, symbolic zeros are not supported. Again
   this isn't a required restriction, but I chose to start without this
   added complexity and we can add support for symbolic zeros as needed
   in the future.
3. More subtly, while this new primitive supports `vmap`, it doesn't
   currently implement rules for composing with the AD system. This
   means that a `custom_vjp` constructed with `optimize_remat=True`
   won't currently work with some approaches to higher-order AD. I
   expect I know how to fix that and will either include that here or in
   a follow-up.
  • Loading branch information
dfm authored and nitins17 committed Aug 27, 2024
1 parent 2282edc commit d8774a4
Show file tree
Hide file tree
Showing 4 changed files with 332 additions and 4 deletions.
247 changes: 243 additions & 4 deletions jax/_src/custom_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from collections.abc import Callable, Sequence
import dataclasses
from functools import update_wrapper, reduce, partial
from functools import update_wrapper, reduce, partial, wraps
import inspect
from typing import Any, Generic, TypeVar

Expand Down Expand Up @@ -497,13 +497,15 @@ def __init__(self,
self.fwd: Callable[..., tuple[ReturnValue, Any]] | None = None
self.bwd: Callable[..., tuple[Any, ...]] | None = None
self.symbolic_zeros = False
self.optimize_remat = False

__getattr__ = custom_api_util.forward_attr

def defvjp(self,
fwd: Callable[..., tuple[ReturnValue, Any]],
bwd: Callable[..., tuple[Any, ...]],
symbolic_zeros: bool = False,
optimize_remat: bool = False,
) -> None:
"""Define a custom VJP rule for the function represented by this instance.
Expand Down Expand Up @@ -560,6 +562,10 @@ def defvjp(self,
objects that are given as input leaves to the ``fwd`` rule.
Default ``False``.
optimize_remat: boolean, an experimental flag to enable an automatic
optimization when this function is used under :func:`jax.remat`. This
will be most useful when the ``fwd`` rule is an opaque call such as a
Pallas kernel or a custom call. Default ``False``.
Returns:
None.
Expand All @@ -582,6 +588,10 @@ def f_bwd(res, g):
self.fwd = fwd
self.bwd = bwd
self.symbolic_zeros = symbolic_zeros
self.optimize_remat = optimize_remat
if self.symbolic_zeros and self.optimize_remat:
raise NotImplementedError(
"remat optimization for custom_vjp does not support symbolic zeros")

@traceback_util.api_boundary
def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable=invalid-annotation
Expand All @@ -591,6 +601,12 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable
raise AttributeError(msg)
fwd_name = getattr(self.fwd, '__name__', str(self.fwd))
args = _resolve_kwargs(self.fun, args, kwargs)
if self.optimize_remat:
fwd = optimize_remat_of_custom_vjp_fwd(
self.fun, self.fwd, nondiff_argnums=self.nondiff_argnums,
symbolic_zeros=self.symbolic_zeros)
else:
fwd = self.fwd
if config.enable_custom_vjp_by_custom_transpose.value:
if self.nondiff_argnums:
raise NotImplementedError(
Expand All @@ -604,16 +620,16 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable
f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), dyn_argnums,
args, require_static_args_hashable=False)
static_args = [args[i] for i in self.nondiff_argnums]
fwd, _ = argnums_partial(lu.wrap_init(self.fwd), dyn_argnums, args,
fwd_, _ = argnums_partial(lu.wrap_init(fwd), dyn_argnums, args,
require_static_args_hashable=False)
bwd = _add_args(lu.wrap_init(self.bwd), static_args)
else:
f_, dyn_args = lu.wrap_init(self.fun), args
fwd, bwd = lu.wrap_init(self.fwd), lu.wrap_init(self.bwd)
fwd_, bwd = lu.wrap_init(fwd), lu.wrap_init(self.bwd)
args_flat, in_tree = tree_flatten(dyn_args)
in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree)
flat_fwd, out_trees = _flatten_fwd(fwd, self.symbolic_zeros, primal_name,
flat_fwd, out_trees = _flatten_fwd(fwd_, self.symbolic_zeros, primal_name,
fwd_name, in_tree, out_type)
flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees).call_wrapped
out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd,
Expand Down Expand Up @@ -1407,3 +1423,226 @@ def jvp(primals, tangents):

# TODO(mattjj): remove these stubs, which exist to avoid breaking internal users
custom_jvp_call_jaxpr_p = core.Primitive("custom_jvp_call_jaxpr")


# The following is a helper for optimizing the behavior of custom_vjp when used
# under remat. This is really only useful when the `fwd` function to custom_vjp
# executes a black box kernel. Otherwise, DCE will perform this optimization
# automatically.
#
# TODO(dfm): Eventually this should probably be the default behavior for
# custom_vjp, if we can make it so that it is a no-op for most cases. Right now,
# it is written in "initial-style" so it doesn't support eager mode. This was
# a reasonable compromise when written because it made the implementation
# simpler, but it would be worth revisiting this.
def optimize_remat_of_custom_vjp_fwd(
fun: Callable[..., ReturnValue],
fwd: Callable[..., tuple[ReturnValue, Any]],
nondiff_argnums: tuple[int, ...] = (),
symbolic_zeros: bool = False,
) -> Callable[..., tuple[ReturnValue, Any]]:
if symbolic_zeros:
# TODO(dfm): This probably shouldn't be too hard to support.
raise NotImplementedError(
"remat optimization for custom_vjp does not support symbolic zeros")

@wraps(fwd)
def wrapped_fwd(*args, **kwargs) -> tuple[ReturnValue, Any]:
# TODO(dfm): This initial logic is duplicated from custom_vjp.__call__
# above and it would be good to consolidate it.
primal_name = getattr(fun, "__name__", str(fun))
fwd_name = getattr(fwd, "__name__", str(fwd))
args = _resolve_kwargs(fwd, args, kwargs)
if nondiff_argnums:
for i in nondiff_argnums: _check_for_tracers(args[i])
nondiff_argnums_ = set(nondiff_argnums)
dyn_argnums = [i for i in range(len(args)) if i not in nondiff_argnums_]
f_, dyn_args = argnums_partial(lu.wrap_init(fun), dyn_argnums,
args, require_static_args_hashable=False)
fwd_, _ = argnums_partial(lu.wrap_init(fwd), dyn_argnums, args,
require_static_args_hashable=False)
else:
f_, dyn_args = lu.wrap_init(fun), args
fwd_ = lu.wrap_init(fwd)
args_flat, in_tree = tree_flatten(dyn_args)
flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree)
flat_fwd, out_trees = _flatten_fwd(fwd_, False, primal_name, fwd_name,
in_tree, out_type)

in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
fwd_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fwd, in_avals)
fwd_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(fwd_jaxpr))
prim_tree, res_tree = out_trees()
num_res = res_tree.num_leaves

if fwd_jaxpr.effects:
raise NotImplementedError(
"remat optimization for custom_vjp does not support forward "
f"functions with side effects, but {fwd_name} has the following "
f"effects: {fwd_jaxpr.effects}")

@pe._memoize
def fun_jaxpr_thunk():
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
return jaxpr, consts

out_flat = remat_opt_p.bind(*consts, *args_flat,
num_consts=len(consts),
num_res=num_res,
fwd_jaxpr=fwd_jaxpr,
fun_jaxpr_thunk=fun_jaxpr_thunk)
res, out_flat = split_list(out_flat, [num_res])
out_tree = treedef_tuple((prim_tree, res_tree))
return tree_unflatten(out_tree, (*out_flat, *res))

return wrapped_fwd

def _remat_opt_impl(
*args,
num_consts: int,
num_res: int,
fwd_jaxpr: core.ClosedJaxpr,
fun_jaxpr_thunk: Callable[[], core.ClosedJaxpr],
):
del num_consts, num_res, fun_jaxpr_thunk # unused
return core.jaxpr_as_fun(fwd_jaxpr)(*args)

def _remat_opt_abstract_eval(*args, fwd_jaxpr: core.ClosedJaxpr, **_):
del args
return fwd_jaxpr.out_avals, fwd_jaxpr.effects

def _remat_opt_vmap(
spmd_axis_name, axis_size, axis_name, main_type, args, in_dims,
*,
num_consts: int,
num_res: int,
fwd_jaxpr: core.ClosedJaxpr,
fun_jaxpr_thunk: Callable[[], core.ClosedJaxpr],
):
args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0
else x for x, d in zip(args, in_dims)]

in_batched = [d is not not_mapped for d in in_dims]
batched_fwd_jaxpr, out_batched = batching.batch_jaxpr(
fwd_jaxpr, axis_size, in_batched, False,
axis_name, spmd_axis_name, main_type)
out_dims = [0 if b else not_mapped for b in out_batched]

_, prim_batched = split_list(in_batched, [num_consts])

@pe._memoize
def batched_fun_jaxpr_thunk():
fun_jaxpr = core.ClosedJaxpr(*fun_jaxpr_thunk())
batched_fun_jaxpr, out_batched = batching.batch_jaxpr(
fun_jaxpr, axis_size, prim_batched, False, axis_name, spmd_axis_name,
main_type)
return batched_fun_jaxpr.jaxpr, batched_fun_jaxpr.consts

batched_outs = remat_opt_p.bind(*args, num_consts=num_consts,
num_res=num_res,
fwd_jaxpr=batched_fwd_jaxpr,
fun_jaxpr_thunk=batched_fun_jaxpr_thunk)

return batched_outs, out_dims

def _remat_opt_jvp(
primals,
tangents,
*,
num_consts: int,
num_res: int,
fwd_jaxpr: core.ClosedJaxpr,
fun_jaxpr_thunk: Callable[[], core.ClosedJaxpr],
):
consts, primals = split_list(primals, [num_consts])
consts_dot, tangents = split_list(tangents, [num_consts])
# Tangents must be instantated in case we end up DCEing later.
tangents = map(ad.instantiate_zeros, tangents)
consts_nz = [not isinstance(t, Zero) for t in consts_dot]
consts_dot = [c for nz, c in zip(consts_nz, consts_dot) if nz]
in_nz = consts_nz + [True] * len(tangents)
fwd_jaxpr_jvp_, out_nz = ad.jvp_jaxpr(fwd_jaxpr, in_nz, True)
num_out = len(out_nz) - num_res
fwd_jaxpr_jvp_ = ad.rearrange_binders(
fwd_jaxpr_jvp_, [num_consts, len(primals)],
[len(consts_dot), len(tangents)], [num_res, num_out], [num_res, num_out])
fwd_jaxpr_jvp = pe.close_jaxpr(pe.convert_constvars_jaxpr(fwd_jaxpr_jvp_.jaxpr))

@pe._memoize
def fun_jvp_jaxpr_thunk():
fun_jaxpr = core.ClosedJaxpr(*fun_jaxpr_thunk())
in_nz = [True] * len(primals)
fun_jvp_jaxpr, _ = ad.jvp_jaxpr(fun_jaxpr, in_nz, True)
return fun_jvp_jaxpr.jaxpr, fun_jvp_jaxpr.consts

new_num_consts = len(fwd_jaxpr_jvp_.consts) + num_consts + len(consts_dot)
outs = remat_opt_p.bind(*fwd_jaxpr_jvp_.consts, *consts, *consts_dot,
*primals, *tangents, num_consts=new_num_consts,
num_res=2 * num_res, fwd_jaxpr=fwd_jaxpr_jvp,
fun_jaxpr_thunk=fun_jvp_jaxpr_thunk)
res, res_dot, outs, outs_dot = split_list(outs, [num_res, num_res, num_out])
return (*res, *outs), (*res_dot, *outs_dot)

def _remat_opt_transpose(
cts, *args,
num_consts: int,
num_res: int,
fwd_jaxpr: core.ClosedJaxpr,
fun_jaxpr_thunk: Callable[[], core.ClosedJaxpr],
):
# TODO(dfm): It shouldn't be too hard to implement this as needed in the
# future.
raise NotImplementedError(
"remat optimization for custom_vjp does not support higher-order AD")

def _remat_opt_dce(used_outs: list[bool], eqn: core.JaxprEqn):
used_res, used_prims = split_list(used_outs, [eqn.params["num_res"]])
outvars = [v for used, v in zip(used_outs, eqn.outvars) if used]
if any(used_res):
# If any of the residuals are used, we still need to run fwd at this point,
# but we may end up DCEing again in the future, so we must instantiate all
# the input primals.
instantiate = [False] * eqn.params["num_consts"]
instantiate += [True] * (len(eqn.invars) - eqn.params["num_consts"])
new_jaxpr, used_ins = pe.dce_jaxpr(eqn.params["fwd_jaxpr"].jaxpr, used_outs,
instantiate=instantiate)
closed_jaxpr = pe.close_jaxpr(new_jaxpr)
invars = [v for used, v in zip(used_ins, eqn.invars) if used]
new_params = dict(eqn.params)
new_num_consts = sum(split_list(used_ins, [eqn.params["num_consts"]])[0])
new_params["num_consts"] = new_num_consts
new_params["fwd_jaxpr"] = closed_jaxpr
new_params["num_res"] = sum(used_res)
new_eqn = pe.new_jaxpr_eqn(
invars, outvars, remat_opt_p, new_params, closed_jaxpr.effects,
eqn.source_info, eqn.ctx)
return used_ins, new_eqn
else:
# If none of the residuals are used, we run the primal computation instead.
# At this point we drop this custom DCE behavior, but since the primal might
# have different consts than fwd, we build a new JaxprEqn with a closed_call
# primitive.
fun_jaxpr, consts = eqn.params["fun_jaxpr_thunk"]()
new_jaxpr, used_consts, used_ins = pe.dce_jaxpr_consts(fun_jaxpr, used_prims)
consts = [c for used, c in zip(used_consts, consts) if used]
closed_jaxpr = core.ClosedJaxpr(new_jaxpr, consts)
_, invars = split_list(eqn.invars, [eqn.params["num_consts"]])
invars = [v for used, v in zip(used_ins, invars) if used]
new_eqn = pe.new_jaxpr_eqn(
invars, outvars, core.closed_call_p, dict(call_jaxpr=closed_jaxpr),
closed_jaxpr.effects, eqn.source_info, eqn.ctx)
used_ins = [False] * eqn.params["num_consts"] + used_ins
return used_ins, new_eqn

remat_opt_p = core.Primitive("remat_opt")
remat_opt_p.multiple_results = True
remat_opt_p.def_impl(_remat_opt_impl)
remat_opt_p.def_effectful_abstract_eval(_remat_opt_abstract_eval)
xla.register_initial_style_primitive(remat_opt_p)
mlir.register_lowering(remat_opt_p, mlir.lower_fun(
_remat_opt_impl, multiple_results=True))
batching.spmd_axis_primitive_batchers[remat_opt_p] = _remat_opt_vmap
batching.axis_primitive_batchers[remat_opt_p] = partial(_remat_opt_vmap, None)
ad.primitive_jvps[remat_opt_p] = _remat_opt_jvp
ad.primitive_transposes[remat_opt_p] = _remat_opt_transpose
pe.dce_rules[remat_opt_p] = _remat_opt_dce
1 change: 1 addition & 0 deletions jax/custom_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
custom_vjp_primal_tree_values as custom_vjp_primal_tree_values,
CustomVJPPrimal as CustomVJPPrimal,
linear_call as linear_call,
remat_opt_p as remat_opt_p,
)

from jax._src.ad_util import (
Expand Down
11 changes: 11 additions & 0 deletions jax/experimental/jax2tf/jax2tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -3448,6 +3448,17 @@ def _custom_lin(*args: TfVal, **_) -> Sequence[TfVal]:
tf_impl[ad.custom_lin_p] = _custom_lin


def _remat_opt(*args: TfVal, num_consts: int, num_res: int,
fwd_jaxpr: core.ClosedJaxpr,
fun_jaxpr_thunk: Callable) -> Sequence[TfVal]:
del num_consts, num_res, fun_jaxpr_thunk
return _interpret_jaxpr(fwd_jaxpr, *args, extra_name_stack="remat_opt",
fresh_constant_cache=False)


tf_impl[custom_derivatives.remat_opt_p] = _remat_opt


PartitionsOrReplicated = Union[tuple[int, ...], None]

def split_to_logical_devices(tensor: TfVal,
Expand Down
Loading

0 comments on commit d8774a4

Please sign in to comment.