Merge pull request #363 from JuliaDiff/ox/config
RuleConfigs (include for calling back into AD)
oxinabox authored Jun 11, 2021
2 parents 7d667c5 + 87fe59e commit 7947bed
Showing 10 changed files with 419 additions and 12 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "ChainRulesCore"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
-version = "0.10.3"
+version = "0.10.4"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
8 changes: 4 additions & 4 deletions docs/Manifest.toml
@@ -13,7 +13,7 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
path = ".."
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
-version = "0.9.44"
+version = "0.10.1"

[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
@@ -34,10 +34,10 @@ deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

[[DocStringExtensions]]
-deps = ["LibGit2", "Markdown", "Pkg", "Test"]
-git-tree-sha1 = "9d4f64f79012636741cf01133158a54b24924c32"
+deps = ["LibGit2"]
+git-tree-sha1 = "a32185f5428d3986f47c2ab78b1f216d5e6cc96f"
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
-version = "0.8.4"
+version = "0.8.5"

[[DocThemeIndigo]]
deps = ["Sass"]
1 change: 1 addition & 0 deletions docs/make.jl
@@ -47,6 +47,7 @@ makedocs(
pages=[
"Introduction" => "index.md",
"FAQ" => "FAQ.md",
+"Rule configurations and calling back into AD" => "config.md",
"Writing Good Rules" => "writing_good_rules.md",
"Complex Numbers" => "complex.md",
"Deriving Array Rules" => "arrays.md",
7 changes: 7 additions & 0 deletions docs/src/api.md
@@ -34,6 +34,13 @@ add!!
ChainRulesCore.is_inplaceable_destination
```

+## RuleConfig
+```@autodocs
+Modules = [ChainRulesCore]
+Pages = ["config.jl"]
+Private = false
+```
+
## Internal
```@docs
ChainRulesCore.AbstractTangent
123 changes: 123 additions & 0 deletions docs/src/config.md
@@ -0,0 +1,123 @@
# [Rule configurations and calling back into AD](@id config)

[`RuleConfig`](@ref) is a mechanism for making rules conditionally defined, based on the presence of certain features in the AD system.
One key such feature is the ability to perform AD in forward mode, reverse mode, or both.

This is done with a trait-like system (not Holy Traits), where the `RuleConfig` has a union of types as its only type parameter.
Each type in the union represents a particular special feature of the AD.
To indicate that the AD system has a special property, its `RuleConfig` should be defined as:
```julia
struct MyADRuleConfig <: RuleConfig{Union{Feature1, Feature2}} end
```
Rules that should only be defined when an AD has a particular special property are written as:
```julia
rrule(::RuleConfig{>:Feature1}, f, args...) = # rrule that should only be defined for ADs with `Feature1`

frule(::RuleConfig{>:Union{Feature1,Feature2}}, f, args...) = # frule that should only be defined for ADs with both `Feature1` and `Feature2`
```
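
To see how the `>:` constraint selects methods, here is a minimal sketch that should run with ChainRulesCore 0.10.4 or later installed. `Feature1` and `Feature2` are the placeholder names used above; the three config types and the `has_feature1`/`has_both` helpers are hypothetical names made up for this illustration and are not part of ChainRulesCore.

```julia
using ChainRulesCore

# Hypothetical feature types, named after the placeholders used above.
struct Feature1 end
struct Feature2 end

# Hypothetical configs declaring one, both, or none of the features.
struct ConfigWithOne <: RuleConfig{Union{Feature1}} end
struct ConfigWithBoth <: RuleConfig{Union{Feature1, Feature2}} end
struct ConfigWithNone <: RuleConfig{Union{}} end

# A method constrained with `>:` applies only when the config's type parameter
# is a supertype of (that is, contains) every requested feature.
has_feature1(::RuleConfig{>:Feature1}) = true
has_feature1(::RuleConfig) = false  # less specific fallback

has_both(::RuleConfig{>:Union{Feature1, Feature2}}) = true
has_both(::RuleConfig) = false      # less specific fallback

has_feature1(ConfigWithOne())   # selects the `>:Feature1` method
has_both(ConfigWithOne())       # selects the fallback
has_both(ConfigWithBoth())      # selects the `>:Union{...}` method
```

The fallback methods play the same role here as the config-less `rrule`/`frule` do for real rules: they are what gets called when the config lacks the feature.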

A prominent use of this is declaring whether or not the AD system supports being called from within rule definitions.

## Declaring support for calling back into ADs

To declare support, or lack of support, for forward and reverse mode, use the two pairs of complementary types.
For reverse mode: [`HasReverseMode`](@ref), [`NoReverseMode`](@ref).
For forwards mode: [`HasForwardsMode`](@ref), [`NoForwardsMode`](@ref).
AD systems that support any calling back into AD should have one type from each pair.

If an AD `HasReverseMode`, then it must define [`rrule_via_ad`](@ref) for its `RuleConfig` subtype.
Similarly, if an AD `HasForwardsMode`, then it must define [`frule_via_ad`](@ref) for its `RuleConfig` subtype.

For example:
```julia
struct MyReverseOnlyADRuleConfig <: RuleConfig{Union{HasReverseMode, NoForwardsMode}} end

function ChainRulesCore.rrule_via_ad(::MyReverseOnlyADRuleConfig, f, args...)
    ...
    return y, pullback
end
```

Note that the same AD does not have to be used for both forward and reverse mode.
For example, [Nabla.jl](https://github.com/invenia/Nabla.jl/) is a reverse-mode AD.
It might declare that it `HasForwardsMode`, and then define a wrapper around [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) in order to provide that capability.

## Writing rules that call back into AD

To define rules for, e.g., higher-order functions, it is useful to be able to call back into the AD system to get it to do some work for you.

For example, the reverse-mode rule for `map` might like to use forward-mode AD if it is available, particularly when only a single input collection is being mapped over.
In that case we know the most efficient way to compute that sub-program is in forward mode, as each call within the `map` takes only a single input.

Note: the following is not the most efficient rule for `map` via forward mode, but it attempts to be clearer for demonstration purposes.

```julia
function rrule(config::RuleConfig{>:HasForwardsMode}, ::typeof(map), f::Function, x::Array{<:Real})
    # real code would support functors/closures, but in interest of keeping example short we exclude it:
    @assert (fieldcount(typeof(f)) == 0) "Functors/Closures are not supported"

    y_and_ẏ = map(x) do xi
        frule_via_ad(config, (NoTangent(), one(xi)), f, xi)
    end
    y = first.(y_and_ẏ)
    ẏ = last.(y_and_ẏ)

    pullback_map(ȳ) = NoTangent(), NoTangent(), ȳ .* ẏ
    return y, pullback_map
end
```
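
As a rough end-to-end sketch, a rule of the same shape can be exercised with a toy config. Everything named here is an assumption made for illustration: `FiniteDiffConfig` is a hypothetical config whose `frule_via_ad` uses a central finite difference as a stand-in for real forward-mode AD (much as a reverse-mode AD might delegate to another package), and `mymap` is a hypothetical wrapper used so that the sketch does not add methods for `Base.map`. It assumes ChainRulesCore 0.10.4 or later is installed.

```julia
using ChainRulesCore

# Hypothetical config: declares forward mode, and no reverse mode.
struct FiniteDiffConfig <: RuleConfig{Union{HasForwardsMode, NoReverseMode}} end

# Stand-in for real forward-mode AD: central finite difference of a scalar function.
# The tangent of `f` itself is ignored in this sketch.
function ChainRulesCore.frule_via_ad(::FiniteDiffConfig, (_, ẋ), f, x::Real)
    h = cbrt(eps(Float64))
    return f(x), ẋ * (f(x + h) - f(x - h)) / (2h)
end

# Hypothetical wrapper around `map`, so no method is added for `Base.map`.
mymap(f, xs) = map(f, xs)

# Reverse-mode rule that calls back into forward mode, one element at a time.
function ChainRulesCore.rrule(
    config::RuleConfig{>:HasForwardsMode}, ::typeof(mymap), f::Function, xs::Vector{<:Real}
)
    y_and_ẏ = map(xi -> frule_via_ad(config, (NoTangent(), one(xi)), f, xi), xs)
    ys = first.(y_and_ẏ)
    ẏs = last.(y_and_ẏ)
    mymap_pullback(ȳ) = (NoTangent(), NoTangent(), ȳ .* ẏs)
    return ys, mymap_pullback
end

ys, pullback = rrule(FiniteDiffConfig(), mymap, sin, [0.0, 1.0])
_, _, x̄ = pullback([1.0, 1.0])  # x̄ approximates cos.([0.0, 1.0])
```

Because the rule only requires `>:HasForwardsMode`, it does not care how the AD system provides forward mode; swapping the finite difference for a real forward-mode implementation would not change the rule.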

## Writing rules that depend on other special requirements of the AD

`>:HasReverseMode` and `>:HasForwardsMode` are two examples of special properties that a `RuleConfig` can declare.
Others could also exist, but right now these are the only two.
It is likely that more will be provided in the future, e.g. for mutation support.

Such a feature would look like:
```julia
struct SupportsMutation end

function rrule(
    ::RuleConfig{>:SupportsMutation}, ::typeof(push!), x::Vector
)
    y = push!(x)

    function push!_pullback(ȳ)
        pop!(x)  # undo change to primal in case it is used in another pullback we haven't called yet
        pop!(ȳ)  # accumulate gradient via mutating ȳ, then return ZeroTangent
        return NoTangent(), ZeroTangent()
    end

    return y, push!_pullback
end
```
and it would be used in the AD e.g. as follows:
```julia
struct EnzymeRuleConfig <: RuleConfig{Union{SupportsMutation, HasReverseMode, NoForwardsMode}} end
```

Note: you can only depend on the presence of a feature, not its absence.
This means we may need to define features and their complements when one is not the obvious default (as in the case of [`HasReverseMode`](@ref)/[`NoReverseMode`](@ref) and [`HasForwardsMode`](@ref)/[`NoForwardsMode`](@ref)).
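
The following sketch illustrates why the complement types are needed; it assumes ChainRulesCore 0.10.4 or later is installed. The config types and the `inner_mode` helper are hypothetical names used only for this illustration. Since dispatch can only ask "does the union contain this type?", a rule that wants to apply when forward mode is absent has to ask for `NoForwardsMode` instead.

```julia
using ChainRulesCore

# Hypothetical configs for illustration.
struct ForwardCapable <: RuleConfig{Union{HasReverseMode, HasForwardsMode}} end
struct ReverseOnly <: RuleConfig{Union{HasReverseMode, NoForwardsMode}} end
struct Undeclared <: RuleConfig{Union{}} end  # declares neither type of the pair

# Hypothetical helper: picks a strategy based on the declared feature.
inner_mode(::RuleConfig{>:HasForwardsMode}) = :forward
inner_mode(::RuleConfig{>:NoForwardsMode}) = :reverse_fallback

inner_mode(ForwardCapable())           # selects the `>:HasForwardsMode` method
inner_mode(ReverseOnly())              # selects the `>:NoForwardsMode` method
applicable(inner_mode, Undeclared())   # neither method matches this config
```

A config that declares neither type of the pair matches neither method, which is why AD systems are asked to declare exactly one type from each pair.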


Such special properties generally should only be defined in `ChainRulesCore`.
(Theoretically, they could be defined elsewhere, but the AD and the package containing the rule need to load them, and ChainRulesCore is the place for things like that.)


## Writing rules that are only for your own AD

A special case of the above is writing rules that are defined only for your own AD.
These are rules that would otherwise be type piracy and would affect other AD systems.
This could be done by making up a special property type and dispatching on it.
But there is no need, as we can dispatch on the `RuleConfig` subtype directly.

For example, in order to avoid mutation in nested AD situations, Zygote might want to have a rule for [`add!!`](@ref) that makes it just do `+`.

```julia
struct ZygoteConfig <: RuleConfig{Union{}} end

rrule(::ZygoteConfig, ::typeof(ChainRulesCore.add!!), a, b) = a+b, Δ->(NoTangent(), Δ, Δ)
```
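
The same pattern can be tried out in isolation with the sketch below, which assumes ChainRulesCore 0.10.4 or later is installed. `double_sum`, `MyADConfig` and `OtherADConfig` are hypothetical names chosen for illustration. The rule is only visible to `MyADConfig`; any other config goes through the config-less fallback, for which no rule exists here.

```julia
using ChainRulesCore

# Hypothetical function and configs for illustration.
double_sum(a, b) = 2 * (a + b)

struct MyADConfig <: RuleConfig{Union{}} end
struct OtherADConfig <: RuleConfig{Union{}} end

# This rule is defined only for `MyADConfig`, by dispatching on the config type itself.
function ChainRulesCore.rrule(::MyADConfig, ::typeof(double_sum), a, b)
    double_sum_pullback(Δ) = (NoTangent(), 2Δ, 2Δ)
    return double_sum(a, b), double_sum_pullback
end

y, pullback = rrule(MyADConfig(), double_sum, 1.0, 2.0)

# Other configs fall back to the config-less `rrule`, and none is defined for `double_sum`.
other = rrule(OtherADConfig(), double_sum, 1.0, 2.0)
```

Note that the method is added as `ChainRulesCore.rrule`, so that it extends the package's function instead of creating a new local one.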

An alternative to writing rules for only one AD would be to add new special property definitions to ChainRulesCore (as described above) that capture what makes that AD special.
7 changes: 6 additions & 1 deletion src/ChainRulesCore.jl
@@ -5,7 +5,11 @@ using SparseArrays: SparseVector, SparseMatrixCSC
using Compat: hasfield

export frule, rrule # core function
-export @non_differentiable, @scalar_rule, @thunk, @not_implemented # definition helper macros
+# rule configurations
+export RuleConfig, HasReverseMode, NoReverseMode, HasForwardsMode, NoForwardsMode
+export frule_via_ad, rrule_via_ad
+# definition helper macros
+export @non_differentiable, @scalar_rule, @thunk, @not_implemented
export canonicalize, extern, unthunk # differential operations
export add!! # gradient accumulation operations
# differentials
@@ -23,6 +27,7 @@ include("differentials/notimplemented.jl")
include("differential_arithmetic.jl")
include("accumulation.jl")

+include("config.jl")
include("rules.jl")
include("rule_definition_tools.jl")

92 changes: 92 additions & 0 deletions src/config.jl
@@ -0,0 +1,92 @@
"""
RuleConfig{T}
The configuration for what rules to use.
`T`: **traits**. This should be a `Union` of all special traits needed for rules to be
allowed to be defined for your AD. If nothing special this should be set to `Union{}`.
**AD authors** should define a subtype of `RuleConfig` to use when calling `frule`/`rrule`.
**Rule authors** can dispatch on this config when defining rules.
For example:
```julia
# only define rrule for `pop!` on AD systems where mutation is supported.
rrule(::RuleConfig{>:SupportsMutation}, ::typeof(pop!), ::Vector) = ...
# this definition of map is for any AD that defines a forwards mode
rrule(conf::RuleConfig{>:HasForwardsMode}, ::typeof(map), ::Vector) = ...
# this definition of map is for any AD that only defines a reverse mode.
# It is not as good as the rrule that can be used if the AD defines a forward-mode as well.
rrule(conf::RuleConfig{>:Union{NoForwardsMode, HasReverseMode}}, ::typeof(map), ::Vector) = ...
```
For more details see [rule configurations and calling back into AD](@ref config).
"""
abstract type RuleConfig{T} end

# Broadcast like a scalar
Base.Broadcast.broadcastable(config::RuleConfig) = Ref(config)

abstract type ReverseModeCapability end

"""
HasReverseMode
This trait indicates that a `RuleConfig{>:HasReverseMode}` can perform reverse mode AD.
If it is set then [`rrule_via_ad`](@ref) must be implemented.
"""
struct HasReverseMode <: ReverseModeCapability end

"""
NoReverseMode
This is the complement to [`HasReverseMode`](@ref). To avoid ambiguities, [`RuleConfig`](@ref)s
that do not support performing reverse mode AD should be `RuleConfig{>:NoReverseMode}`.
"""
struct NoReverseMode <: ReverseModeCapability end

abstract type ForwardsModeCapability end

"""
HasForwardsMode
This trait indicates that a `RuleConfig{>:HasForwardsMode}` can perform forward mode AD.
If it is set then [`frule_via_ad`](@ref) must be implemented.
"""
struct HasForwardsMode <: ForwardsModeCapability end

"""
NoForwardsMode
This is the complement to [`HasForwardsMode`](@ref). To avoid ambiguities, [`RuleConfig`](@ref)s
that do not support performing forwards mode AD should be `RuleConfig{>:NoForwardsMode}`.
"""
struct NoForwardsMode <: ForwardsModeCapability end


"""
frule_via_ad(::RuleConfig{>:HasForwardsMode}, ȧrgs, f, args...; kwargs...)
This function has the same API as [`frule`](@ref), but operates via performing forwards mode
automatic differentiation.
Any `RuleConfig` subtype that supports the [`HasForwardsMode`](@ref) special feature must
provide an implementation of it.
See also: [`rrule_via_ad`](@ref), [`RuleConfig`](@ref) and the documentation on
[rule configurations and calling back into AD](@ref config)
"""
function frule_via_ad end

"""
rrule_via_ad(::RuleConfig{>:HasReverseMode}, f, args...; kwargs...)
This function has the same API as [`rrule`](@ref), but operates via performing reverse mode
automatic differentiation.
Any `RuleConfig` subtype that supports the [`HasReverseMode`](@ref) special feature must
provide an implementation of it.
See also: [`frule_via_ad`](@ref), [`RuleConfig`](@ref) and the documentation on
[rule configurations and calling back into AD](@ref config)
"""
function rrule_via_ad end
27 changes: 21 additions & 6 deletions src/rules.jl
@@ -1,5 +1,5 @@
"""
-frule((Δf, Δx...), f, x...)
+frule([::RuleConfig,] (Δf, Δx...), f, x...)
Expressing the output of `f(x...)` as `Ω`, return the tuple:
@@ -50,15 +50,21 @@ So this is actually a [`Tangent`](@ref):
```jldoctest frule
julia> Δsincosx
Tangent{Tuple{Float64, Float64}}(0.6795498147167869, -0.7336293678134624)
-```.
+```
+The optional [`RuleConfig`](@ref) option allows specifying frules only for AD systems that
+support given features. If not needed, then it can be omitted and the `frule` without it
+will be hit as a fallback. This is the case for most rules.
-See also: [`rrule`](@ref), [`@scalar_rule`](@ref)
+See also: [`rrule`](@ref), [`@scalar_rule`](@ref), [`RuleConfig`](@ref)
"""
-frule(::Any, ::Vararg{Any}; kwargs...) = nothing
+frule(::Any, ::Any, ::Vararg{Any}; kwargs...) = nothing
+
+# if no config is present then fallback to config-less rules
+frule(::RuleConfig, ȧrgs, f, args...; kwargs...) = frule(ȧrgs, f, args...; kwargs...)

"""
-rrule(f, x...)
+rrule([::RuleConfig,] f, x...)
Expressing `x` as the tuple `(x₁, x₂, ...)` and the output tuple of `f(x...)`
as `Ω`, return the tuple:
@@ -101,10 +107,19 @@ julia> hypot_pullback(1) == (NoTangent(), (x / hypot(x, y)), (y / hypot(x, y)))
true
```
-See also: [`frule`](@ref), [`@scalar_rule`](@ref)
+The optional [`RuleConfig`](@ref) option allows specifying rrules only for AD systems that
+support given features. If not needed, then it can be omitted and the `rrule` without it
+will be hit as a fallback. This is the case for most rules.
+See also: [`frule`](@ref), [`@scalar_rule`](@ref), [`RuleConfig`](@ref)
"""
rrule(::Any, ::Vararg{Any}) = nothing

+# if no config is present then fallback to config-less rules
+rrule(::RuleConfig, f, args...; kwargs...) = rrule(f, args...; kwargs...)
+# TODO do we need to do something for kwargs special here for performance?
+# See: https://github.com/JuliaDiff/ChainRulesCore.jl/issues/368

# Manual fallback for keyword arguments. Usually this would be generated by
#
# rrule(::Any, ::Vararg{Any}; kwargs...) = nothing

2 comments on commit 7947bed

@oxinabox

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/38665

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.10.4 -m "<description of version>" 7947bed6722e488d8fb774cd55351785bd37f0d3
git push origin v0.10.4
