-
Notifications
You must be signed in to change notification settings - Fork 62
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #363 from JuliaDiff/ox/config
RuleConfigs (include for calling back into AD)
- Loading branch information
Showing
10 changed files
with
419 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
# [Rule configurations and calling back into AD](@id config) | ||
|
||
[`RuleConfig`](@ref) is a method for making rules conditionally defined based on the presence of certain features in the AD system. | ||
One key such feature is the ability to perform AD either in forwards or reverse mode or both. | ||
|
||
This is done with a trait-like system (not Holy Traits), where the `RuleConfig` has a union of types as its only type-parameter. | ||
Where each type represents a particular special feature of this AD. | ||
To indicate that the AD system has a special property, its `RuleConfig` should be defined as: | ||
```julia | ||
struct MyADRuleConfig <: RuleConfig{Union{Feature1, Feature2}} end | ||
``` | ||
And rules that should only be defined when an AD has a particular special property write: | ||
```julia | ||
rrule(::RuleConfig{>:Feature1}, f, args...) = # rrule that should only be define for ADs with `Feature1` | ||
|
||
frule(::RuleConfig{>:Union{Feature1,Feature2}}, f, args...) = # frule that should only be define for ADs with both `Feature1` and `Feature2` | ||
``` | ||
|
||
A prominent use of this is in declaring that the AD system can, or cannot support being called from within the rule definitions. | ||
|
||
## Declaring support for calling back into ADs | ||
|
||
To declare support or lack of support for forward and reverse-mode, use the two pairs of complementary types. | ||
For reverse mode: [`HasReverseMode`](@ref), [`NoReverseMode`](@ref). | ||
For forwards mode: [`HasForwardsMode`](@ref), [`NoForwardsMode`](@ref). | ||
AD systems that support any calling back into AD should have one from each set. | ||
|
||
If an AD `HasReverseMode`, then it must define [`rrule_via_ad`](@ref) for that RuleConfig subtype. | ||
Similarly, if an AD `HasForwardsMode` then it must define [`frule_via_ad`](@ref) for that RuleConfig subtype. | ||
|
||
For example: | ||
```julia | ||
struct MyReverseOnlyADRuleConfig <: RuleConfig{Union{HasReverseMode, NoForwardsMode}} end | ||
|
||
function ChainRulesCore.rrule_via_ad(::MyReverseOnlyADRuleConfig, f, args...) | ||
... | ||
return y, pullback | ||
end | ||
``` | ||
|
||
Note that it is not actually required that the same AD is used for forward and reverse. | ||
For example [Nabla.jl](https://github.com/invenia/Nabla.jl/) is a reverse mode AD. | ||
It might declare that it `HasForwardsMode`, and then define a wrapper around [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) in order to provide that capacity. | ||
|
||
## Writing rules that call back into AD | ||
|
||
To define e.g. rules for higher order functions, it is useful to be able to call back into the AD system to get it to do some work for you. | ||
|
||
For example the rule for reverse mode AD for `map` might like to use forward mode AD if one is available. | ||
Particularly for the case where only a single input collection is being mapped over. | ||
In that case we know the most efficient way to compute that sub-program is in forwards, as each call with-in the map only takes a single input. | ||
|
||
Note: the following is not the most efficient rule for `map` via forward, but attempts to be clearer for demonstration purposes. | ||
|
||
```julia | ||
function rrule(config::RuleConfig{>:HasForwardsMode}, ::typeof(map), f::Function, x::Array{<:Real}) | ||
# real code would support functors/closures, but in interest of keeping example short we exclude it: | ||
@assert (fieldcount(typeof(f)) == 0) "Functors/Closures are not supported" | ||
|
||
y_and_ẏ = map(x) do xi | ||
frule_via_ad(config, (NoTangent(), one(xi)), f, xi) | ||
end | ||
y = first.(y_and_ẏ) | ||
ẏ = last.(y_and_ẏ) | ||
|
||
pullback_map(ȳ) = NoTangent(), NoTangent(), ȳ .* ẏ | ||
return y, pullback_map | ||
end | ||
``` | ||
|
||
## Writing rules that depend on other special requirements of the AD. | ||
|
||
The `>:HasReverseMode` and `>:HasForwardsMode` are two examples of special properties that a `RuleConfig` could allow. | ||
Others could also exist, but right now they are the only two. | ||
It is likely that in the future such will be provided for e.g. mutation support. | ||
|
||
Such a thing would look like: | ||
```julia | ||
struct SupportsMutation end | ||
|
||
function rrule( | ||
::RuleConfig{>:SupportsMutatation}, typeof(push!), x::Vector | ||
) | ||
y = push!(x) | ||
|
||
function push!_pullback(ȳ) | ||
pop!(x) # undo change to primal incase it is used in another pullback we haven't called yet | ||
pop!(ȳ) # accumulate gradient via mutating ȳ, then return ZeroTangent | ||
return NoTangent(), ZeroTangent() | ||
end | ||
|
||
return y, push!_pullback | ||
end | ||
``` | ||
and it would be used in the AD e.g. as follows: | ||
```julia | ||
struct EnzymeRuleConfig <: RuleConfig{Union{SupportsMutation, HasReverseMode, NoForwardsMode}} | ||
``` | ||
|
||
Note: you can only depend on the presence of a feature, not its absence. | ||
This means we may need to define features and their compliments, when one is not the obvious default (as in the fast of [`HasReverseMode`](@ref)/[`NoReverseMode`](@ref) and [`HasForwardsMode`](@ref)/[`NoForwardsMode`](@ref).). | ||
|
||
|
||
Such special properties generally should only be defines in `ChainRulesCore`. | ||
(Theoretically, they could be defined elsewhere, but the AD and the package containing the rule need to load them, and ChainRulesCore is the place for things like that.) | ||
|
||
|
||
## Writing rules that are only for your own AD | ||
|
||
A special case of the above is writing rules that are defined only for your own AD. | ||
Rules which otherwise would be type-piracy, and would affect other AD systems. | ||
This could be done via making up a special property type and dispatching on it. | ||
But there is no need, as we can dispatch on the `RuleConfig` subtype directly. | ||
|
||
For example in order to avoid mutation in nested AD situations, Zygote might want to have a rule for [`add!!`](@ref) that makes it just do `+`. | ||
|
||
```julia | ||
struct ZygoteConfig <: RuleConfig{Union{}} end | ||
|
||
rrule(::ZygoteConfig, typeof(ChainRulesCore.add!!), a, b) = a+b, Δ->(NoTangent(), Δ, Δ) | ||
``` | ||
|
||
As an alternative to rules only for one AD, would be to add new special property definitions to ChainRulesCore (as described above) which would capture what makes that AD special. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
""" | ||
RuleConfig{T} | ||
The configuration for what rules to use. | ||
`T`: **traits**. This should be a `Union` of all special traits needed for rules to be | ||
allowed to be defined for your AD. If nothing special this should be set to `Union{}`. | ||
**AD authors** should define a subtype of `RuleConfig` to use when calling `frule`/`rrule`. | ||
**Rule authors** can dispatch on this config when defining rules. | ||
For example: | ||
```julia | ||
# only define rrule for `pop!` on AD systems where mutation is supported. | ||
rrule(::RuleConfig{>:SupportsMutation}, typeof(pop!), ::Vector) = ... | ||
# this definition of map is for any AD that defines a forwards mode | ||
rrule(conf::RuleConfig{>:HasForwardsMode}, typeof(map), ::Vector) = ... | ||
# this definition of map is for any AD that only defines a reverse mode. | ||
# It is not as good as the rrule that can be used if the AD defines a forward-mode as well. | ||
rrule(conf::RuleConfig{>:Union{NoForwardsMode, HasReverseMode}}, typeof(map), ::Vector) = ... | ||
``` | ||
For more details see [rule configurations and calling back into AD](@ref config). | ||
""" | ||
abstract type RuleConfig{T} end | ||
|
||
# Broadcast like a scalar | ||
Base.Broadcast.broadcastable(config::RuleConfig) = Ref(config) | ||
|
||
abstract type ReverseModeCapability end | ||
|
||
""" | ||
HasReverseMode | ||
This trait indicates that a `RuleConfig{>:HasReverseMode}` can perform reverse mode AD. | ||
If it is set then [`rrule_via_ad`](@ref) must be implemented. | ||
""" | ||
struct HasReverseMode <: ReverseModeCapability end | ||
|
||
""" | ||
NoReverseMode | ||
This is the complement to [`HasReverseMode`](@ref). To avoid ambiguities [`RuleConfig`]s | ||
that do not support performing reverse mode AD should be `RuleConfig{>:NoReverseMode}`. | ||
""" | ||
struct NoReverseMode <: ReverseModeCapability end | ||
|
||
abstract type ForwardsModeCapability end | ||
|
||
""" | ||
HasForwardsMode | ||
This trait indicates that a `RuleConfig{>:HasForwardsMode}` can perform forward mode AD. | ||
If it is set then [`frule_via_ad`](@ref) must be implemented. | ||
""" | ||
struct HasForwardsMode <: ForwardsModeCapability end | ||
|
||
""" | ||
NoForwardsMode | ||
This is the complement to [`HasForwardsMode`](@ref). To avoid ambiguities [`RuleConfig`]s | ||
that do not support performing forwards mode AD should be `RuleConfig{>:NoForwardsMode}`. | ||
""" | ||
struct NoForwardsMode <: ForwardsModeCapability end | ||
|
||
|
||
""" | ||
frule_via_ad(::RuleConfig{>:HasForwardsMode}, ȧrgs, f, args...; kwargs...) | ||
This function has the same API as [`frule`](@ref), but operates via performing forwards mode | ||
automatic differentiation. | ||
Any `RuleConfig` subtype that supports the [`HasForwardsMode`](@ref) special feature must | ||
provide an implementation of it. | ||
See also: [`rrule_via_ad`](@ref), [`RuleConfig`](@ref) and the documentation on | ||
[rule configurations and calling back into AD](@ref config) | ||
""" | ||
function frule_via_ad end | ||
|
||
""" | ||
rrule_via_ad(::RuleConfig{>:HasReverseMode}, f, args...; kwargs...) | ||
This function has the same API as [`rrule`](@ref), but operates via performing reverse mode | ||
automatic differentiation. | ||
Any `RuleConfig` subtype that supports the [`HasReverseMode`](@ref) special feature must | ||
provide an implementation of it. | ||
See also: [`frule_via_ad`](@ref), [`RuleConfig`](@ref) and the documentation on | ||
[rule configurations and calling back into AD](@ref config) | ||
""" | ||
function rrule_via_ad end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
7947bed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@JuliaRegistrator register
7947bed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Registration pull request created: JuliaRegistries/General/38665
After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.
This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via: