forked from jax-ml/jax
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add optional automatic remat optimization to custom_vjp.
As reported in jax-ml#21303, using `remat` with `custom_vjp` can produce inefficient results. The high level summary is that computing the grad of such a function results in the `fwd` function of the `custom_vjp` being evaluated twice, even though the first time the residuals are not actually used. In many cases this isn't a problem because DCE will clean up the unnecessary computations. But, when the fwd function requires an opaque call (e.g. pallas_call or ffi_call), this no longer saves the day. In this PR, I have added a parameter to `custom_vjp` called `optimize_remat` (open for discussion!), which can be used to opt-in to automatic optimization of this operation. Setting this flag to true results in the `fwd` function being wrapped in a new custom primitive which will DCE into a call to the primal function whenever the residuals are unused. This can be used to fix jax-ml#21303, and I think it would make sense to eventually make this behavior the default, but this implementation comes with a few caveats: 1. This feature is currently implemented in "initial style", which means that the `fwd` function is traced to a jaxpr when it is initially called. This means that when `optimize_remat=True`, the `custom_vjp` function doesn't support data dependent conditionals within `fwd`. This isn't a fundamental limitation of the method, but this implementation is much simpler so it seemed like a good place to start, and much of the complexity of the "final style" version of this logic should be simplified by work that @dougalm is doing. Furthermore, for the immediate use case of opaque calls, initial style is not a serious limitation. 2. When `optimize_remat=True`, symbolic zeros are not supported. Again this isn't a required restriction, but I chose to start without this added complexity and we can add support for symbolic zeros as needed in the future. 3. More subtly, while this new primitive supports `vmap`, it doesn't currently implement rules for composing with the AD system. This means that a `custom_vjp` constructed with `optimize_remat=True` won't currently work with some approaches to higher-order AD. I expect I know how to fix that and will either include that here or in a follow-up.
- Loading branch information
Showing
4 changed files
with
332 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.