diff --git a/Project.toml b/Project.toml index b46577795..a4086ec19 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.10.3" +version = "0.10.4" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" diff --git a/docs/Manifest.toml b/docs/Manifest.toml index 67ca344dd..22a04e662 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -13,7 +13,7 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" deps = ["Compat", "LinearAlgebra", "SparseArrays"] path = ".." uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.9.44" +version = "0.10.1" [[Compat]] deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] @@ -34,10 +34,10 @@ deps = ["Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" [[DocStringExtensions]] -deps = ["LibGit2", "Markdown", "Pkg", "Test"] -git-tree-sha1 = "9d4f64f79012636741cf01133158a54b24924c32" +deps = ["LibGit2"] +git-tree-sha1 = "a32185f5428d3986f47c2ab78b1f216d5e6cc96f" uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -version = "0.8.4" +version = "0.8.5" [[DocThemeIndigo]] deps = ["Sass"] diff --git a/docs/make.jl b/docs/make.jl index ea4876bb4..0cbcfc717 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -47,6 +47,7 @@ makedocs( pages=[ "Introduction" => "index.md", "FAQ" => "FAQ.md", + "Rule configurations and calling back into AD" => "config.md", "Writing Good Rules" => "writing_good_rules.md", "Complex Numbers" => "complex.md", "Deriving Array Rules" => "arrays.md", diff --git a/docs/src/api.md b/docs/src/api.md index fc6ac50ab..5fcd53400 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -34,6 +34,13 @@ add!! ChainRulesCore.is_inplaceable_destination ``` +## RuleConfig +```@autodocs +Modules = [ChainRulesCore] +Pages = ["config.jl"] +Private = false +``` + ## Internal ```@docs ChainRulesCore.AbstractTangent diff --git a/docs/src/config.md b/docs/src/config.md new file mode 100644 index 000000000..8997bd7a4 --- /dev/null +++ b/docs/src/config.md @@ -0,0 +1,123 @@ +# [Rule configurations and calling back into AD](@id config) + +[`RuleConfig`](@ref) is a method for making rules conditionally defined based on the presence of certain features in the AD system. +One key such feature is the ability to perform AD either in forwards or reverse mode or both. + +This is done with a trait-like system (not Holy Traits), where the `RuleConfig` has a union of types as its only type-parameter. +Where each type represents a particular special feature of this AD. +To indicate that the AD system has a special property, its `RuleConfig` should be defined as: +```julia +struct MyADRuleConfig <: RuleConfig{Union{Feature1, Feature2}} end +``` +And rules that should only be defined when an AD has a particular special property write: +```julia +rrule(::RuleConfig{>:Feature1}, f, args...) = # rrule that should only be define for ADs with `Feature1` + +frule(::RuleConfig{>:Union{Feature1,Feature2}}, f, args...) = # frule that should only be define for ADs with both `Feature1` and `Feature2` +``` + +A prominent use of this is in declaring that the AD system can, or cannot support being called from within the rule definitions. + +## Declaring support for calling back into ADs + +To declare support or lack of support for forward and reverse-mode, use the two pairs of complementary types. +For reverse mode: [`HasReverseMode`](@ref), [`NoReverseMode`](@ref). +For forwards mode: [`HasForwardsMode`](@ref), [`NoForwardsMode`](@ref). +AD systems that support any calling back into AD should have one from each set. + +If an AD `HasReverseMode`, then it must define [`rrule_via_ad`](@ref) for that RuleConfig subtype. +Similarly, if an AD `HasForwardsMode` then it must define [`frule_via_ad`](@ref) for that RuleConfig subtype. + +For example: +```julia +struct MyReverseOnlyADRuleConfig <: RuleConfig{Union{HasReverseMode, NoForwardsMode}} end + +function ChainRulesCore.rrule_via_ad(::MyReverseOnlyADRuleConfig, f, args...) + ... + return y, pullback +end +``` + +Note that it is not actually required that the same AD is used for forward and reverse. +For example [Nabla.jl](https://github.com/invenia/Nabla.jl/) is a reverse mode AD. +It might declare that it `HasForwardsMode`, and then define a wrapper around [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) in order to provide that capacity. + +## Writing rules that call back into AD + +To define e.g. rules for higher order functions, it is useful to be able to call back into the AD system to get it to do some work for you. + +For example the rule for reverse mode AD for `map` might like to use forward mode AD if one is available. +Particularly for the case where only a single input collection is being mapped over. +In that case we know the most efficient way to compute that sub-program is in forwards, as each call with-in the map only takes a single input. + +Note: the following is not the most efficient rule for `map` via forward, but attempts to be clearer for demonstration purposes. + +```julia +function rrule(config::RuleConfig{>:HasForwardsMode}, ::typeof(map), f::Function, x::Array{<:Real}) + # real code would support functors/closures, but in interest of keeping example short we exclude it: + @assert (fieldcount(typeof(f)) == 0) "Functors/Closures are not supported" + + y_and_ẏ = map(x) do xi + frule_via_ad(config, (NoTangent(), one(xi)), f, xi) + end + y = first.(y_and_ẏ) + ẏ = last.(y_and_ẏ) + + pullback_map(ȳ) = NoTangent(), NoTangent(), ȳ .* ẏ + return y, pullback_map +end +``` + +## Writing rules that depend on other special requirements of the AD. + +The `>:HasReverseMode` and `>:HasForwardsMode` are two examples of special properties that a `RuleConfig` could allow. +Others could also exist, but right now they are the only two. +It is likely that in the future such will be provided for e.g. mutation support. + +Such a thing would look like: +```julia +struct SupportsMutation end + +function rrule( + ::RuleConfig{>:SupportsMutatation}, typeof(push!), x::Vector +) + y = push!(x) + + function push!_pullback(ȳ) + pop!(x) # undo change to primal incase it is used in another pullback we haven't called yet + pop!(ȳ) # accumulate gradient via mutating ȳ, then return ZeroTangent + return NoTangent(), ZeroTangent() + end + + return y, push!_pullback +end +``` +and it would be used in the AD e.g. as follows: +```julia +struct EnzymeRuleConfig <: RuleConfig{Union{SupportsMutation, HasReverseMode, NoForwardsMode}} +``` + +Note: you can only depend on the presence of a feature, not its absence. +This means we may need to define features and their compliments, when one is not the obvious default (as in the fast of [`HasReverseMode`](@ref)/[`NoReverseMode`](@ref) and [`HasForwardsMode`](@ref)/[`NoForwardsMode`](@ref).). + + +Such special properties generally should only be defines in `ChainRulesCore`. +(Theoretically, they could be defined elsewhere, but the AD and the package containing the rule need to load them, and ChainRulesCore is the place for things like that.) + + +## Writing rules that are only for your own AD + +A special case of the above is writing rules that are defined only for your own AD. +Rules which otherwise would be type-piracy, and would affect other AD systems. +This could be done via making up a special property type and dispatching on it. +But there is no need, as we can dispatch on the `RuleConfig` subtype directly. + +For example in order to avoid mutation in nested AD situations, Zygote might want to have a rule for [`add!!`](@ref) that makes it just do `+`. + +```julia +struct ZygoteConfig <: RuleConfig{Union{}} end + +rrule(::ZygoteConfig, typeof(ChainRulesCore.add!!), a, b) = a+b, Δ->(NoTangent(), Δ, Δ) +``` + +As an alternative to rules only for one AD, would be to add new special property definitions to ChainRulesCore (as described above) which would capture what makes that AD special. diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 9b2b304fc..2655ee87e 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -5,7 +5,11 @@ using SparseArrays: SparseVector, SparseMatrixCSC using Compat: hasfield export frule, rrule # core function -export @non_differentiable, @scalar_rule, @thunk, @not_implemented # definition helper macros +# rule configurations +export RuleConfig, HasReverseMode, NoReverseMode, HasForwardsMode, NoForwardsMode +export frule_via_ad, rrule_via_ad +# definition helper macros +export @non_differentiable, @scalar_rule, @thunk, @not_implemented export canonicalize, extern, unthunk # differential operations export add!! # gradient accumulation operations # differentials @@ -23,6 +27,7 @@ include("differentials/notimplemented.jl") include("differential_arithmetic.jl") include("accumulation.jl") +include("config.jl") include("rules.jl") include("rule_definition_tools.jl") diff --git a/src/config.jl b/src/config.jl new file mode 100644 index 000000000..347e05c51 --- /dev/null +++ b/src/config.jl @@ -0,0 +1,92 @@ +""" + RuleConfig{T} + +The configuration for what rules to use. +`T`: **traits**. This should be a `Union` of all special traits needed for rules to be +allowed to be defined for your AD. If nothing special this should be set to `Union{}`. + +**AD authors** should define a subtype of `RuleConfig` to use when calling `frule`/`rrule`. + +**Rule authors** can dispatch on this config when defining rules. +For example: +```julia +# only define rrule for `pop!` on AD systems where mutation is supported. +rrule(::RuleConfig{>:SupportsMutation}, typeof(pop!), ::Vector) = ... + +# this definition of map is for any AD that defines a forwards mode +rrule(conf::RuleConfig{>:HasForwardsMode}, typeof(map), ::Vector) = ... + +# this definition of map is for any AD that only defines a reverse mode. +# It is not as good as the rrule that can be used if the AD defines a forward-mode as well. +rrule(conf::RuleConfig{>:Union{NoForwardsMode, HasReverseMode}}, typeof(map), ::Vector) = ... +``` + +For more details see [rule configurations and calling back into AD](@ref config). +""" +abstract type RuleConfig{T} end + +# Broadcast like a scalar +Base.Broadcast.broadcastable(config::RuleConfig) = Ref(config) + +abstract type ReverseModeCapability end + +""" +HasReverseMode + +This trait indicates that a `RuleConfig{>:HasReverseMode}` can perform reverse mode AD. +If it is set then [`rrule_via_ad`](@ref) must be implemented. +""" +struct HasReverseMode <: ReverseModeCapability end + +""" +NoReverseMode + +This is the complement to [`HasReverseMode`](@ref). To avoid ambiguities [`RuleConfig`]s +that do not support performing reverse mode AD should be `RuleConfig{>:NoReverseMode}`. +""" +struct NoReverseMode <: ReverseModeCapability end + +abstract type ForwardsModeCapability end + +""" +HasForwardsMode + +This trait indicates that a `RuleConfig{>:HasForwardsMode}` can perform forward mode AD. +If it is set then [`frule_via_ad`](@ref) must be implemented. +""" +struct HasForwardsMode <: ForwardsModeCapability end + +""" +NoForwardsMode + +This is the complement to [`HasForwardsMode`](@ref). To avoid ambiguities [`RuleConfig`]s +that do not support performing forwards mode AD should be `RuleConfig{>:NoForwardsMode}`. +""" +struct NoForwardsMode <: ForwardsModeCapability end + + +""" + frule_via_ad(::RuleConfig{>:HasForwardsMode}, ȧrgs, f, args...; kwargs...) + +This function has the same API as [`frule`](@ref), but operates via performing forwards mode +automatic differentiation. +Any `RuleConfig` subtype that supports the [`HasForwardsMode`](@ref) special feature must +provide an implementation of it. + +See also: [`rrule_via_ad`](@ref), [`RuleConfig`](@ref) and the documentation on +[rule configurations and calling back into AD](@ref config) +""" +function frule_via_ad end + +""" + rrule_via_ad(::RuleConfig{>:HasReverseMode}, f, args...; kwargs...) + +This function has the same API as [`rrule`](@ref), but operates via performing reverse mode +automatic differentiation. +Any `RuleConfig` subtype that supports the [`HasReverseMode`](@ref) special feature must +provide an implementation of it. + +See also: [`frule_via_ad`](@ref), [`RuleConfig`](@ref) and the documentation on +[rule configurations and calling back into AD](@ref config) +""" +function rrule_via_ad end diff --git a/src/rules.jl b/src/rules.jl index 391a2debe..232b1c558 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -1,5 +1,5 @@ """ - frule((Δf, Δx...), f, x...) + frule([::RuleConfig,] (Δf, Δx...), f, x...) Expressing the output of `f(x...)` as `Ω`, return the tuple: @@ -50,15 +50,21 @@ So this is actually a [`Tangent`](@ref): ```jldoctest frule julia> Δsincosx Tangent{Tuple{Float64, Float64}}(0.6795498147167869, -0.7336293678134624) -```. +``` +The optional [`RuleConfig`](@ref) option allows specifying frules only for AD systems that +support given features. If not needed, then it can be omitted and the `frule` without it +will be hit as a fallback. This is the case for most rules. -See also: [`rrule`](@ref), [`@scalar_rule`](@ref) +See also: [`rrule`](@ref), [`@scalar_rule`](@ref), [`RuleConfig`](@ref) """ -frule(::Any, ::Vararg{Any}; kwargs...) = nothing +frule(::Any, ::Any, ::Vararg{Any}; kwargs...) = nothing + +# if no config is present then fallback to config-less rules +frule(::RuleConfig, ȧrgs, f, args...; kwargs...) = frule(ȧrgs, f, args...; kwargs...) """ - rrule(f, x...) + rrule([::RuleConfig,] f, x...) Expressing `x` as the tuple `(x₁, x₂, ...)` and the output tuple of `f(x...)` as `Ω`, return the tuple: @@ -101,10 +107,19 @@ julia> hypot_pullback(1) == (NoTangent(), (x / hypot(x, y)), (y / hypot(x, y))) true ``` -See also: [`frule`](@ref), [`@scalar_rule`](@ref) +The optional [`RuleConfig`](@ref) option allows specifying rrules only for AD systems that +support given features. If not needed, then it can be omitted and the `rrule` without it +will be hit as a fallback. This is the case for most rules. + +See also: [`frule`](@ref), [`@scalar_rule`](@ref), [`RuleConfig`](@ref) """ rrule(::Any, ::Vararg{Any}) = nothing +# if no config is present then fallback to config-less rules +rrule(::RuleConfig, f, args...; kwargs...) = rrule(f, args...; kwargs...) +# TODO do we need to do something for kwargs special here for performance? +# See: https://github.com/JuliaDiff/ChainRulesCore.jl/issues/368 + # Manual fallback for keyword arguments. Usually this would be generated by # # rrule(::Any, ::Vararg{Any}; kwargs...) = nothing diff --git a/test/config.jl b/test/config.jl new file mode 100644 index 000000000..c55c0b6c8 --- /dev/null +++ b/test/config.jl @@ -0,0 +1,163 @@ +# Define a bunch of configs for testing purposes +struct MostBoringConfig <: RuleConfig{Union{}} end + +struct MockForwardsConfig <: RuleConfig{Union{HasForwardsMode, NoReverseMode}} + forward_calls::Vector +end +MockForwardsConfig() = MockForwardsConfig([]) +function ChainRulesCore.frule_via_ad(config::MockForwardsConfig, ȧrgs, f, args...; kws...) + # For testing purposes we only support giving right answer for identity functions + push!(config.forward_calls, (f, args)) + return f(args...; kws...), ȧrgs +end + +struct MockReverseConfig <: RuleConfig{Union{NoForwardsMode, HasReverseMode}} + reverse_calls::Vector +end +MockReverseConfig() = MockReverseConfig([]) +function ChainRulesCore.rrule_via_ad(config::MockReverseConfig, f, args...; kws...) + # For testing purposes we only support giving right answer for identity functions + push!(config.reverse_calls, (f, args)) + pullback_via_ad(ȳ) = NoTangent(), ȳ + return f(args...; kws...), pullback_via_ad +end + + +struct MockBothConfig <: RuleConfig{Union{HasForwardsMode, HasReverseMode}} + forward_calls::Vector + reverse_calls::Vector +end +MockBothConfig() = MockBothConfig([], []) +function ChainRulesCore.frule_via_ad(config::MockBothConfig, ȧrgs, f, args...; kws...) + # For testing purposes we only support giving right answer for identity functions + push!(config.forward_calls, (f, args)) + return f(args...; kws...), ȧrgs +end + +function ChainRulesCore.rrule_via_ad(config::MockBothConfig, f, args...; kws...) + # For testing purposes we only support giving right answer for identity functions + push!(config.reverse_calls, (f, args)) + pullback_via_ad(ȳ) = NoTangent(), ȳ + return f(args...; kws...), pullback_via_ad +end + +############################## + +#define some functions for testing + +@testset "config.jl" begin + @testset "basic fall to two arg verion for $Config" for Config in ( + MostBoringConfig, MockForwardsConfig, MockReverseConfig, MockBothConfig, + ) + counting_id_count = Ref(0) + function counting_id(x) + counting_id_count[]+=1 + return x + end + function ChainRulesCore.rrule(::typeof(counting_id), x) + counting_id_pullback(x̄) = x̄ + return counting_id(x), counting_id_pullback + end + function ChainRulesCore.frule((dself, dx),::typeof(counting_id), x) + return counting_id(x), dx + end + @testset "rrule" begin + counting_id_count[] = 0 + @test rrule(Config(), counting_id, 21.5) !== nothing + @test counting_id_count[] == 1 + end + @testset "frule" begin + counting_id_count[] = 0 + @test frule(Config(), (NoTangent(), 11.2), counting_id, 32.4) !== nothing + @test counting_id_count[] == 1 + end + end + + @testset "hitting forwards AD" begin + do_thing_2(f, x) = f(x) + function ChainRulesCore.frule( + config::RuleConfig{>:HasForwardsMode}, (_, df, dx), ::typeof(do_thing_2), f, x + ) + return frule_via_ad(config, (df, dx), f, x) + end + + @testset "$Config" for Config in (MostBoringConfig, MockReverseConfig) + @test nothing === frule( + Config(), (NoTangent(), NoTangent(), 21.5), do_thing_2, identity, 32.1 + ) + end + + @testset "$Config" for Config in (MockBothConfig, MockForwardsConfig) + bconfig= Config() + @test nothing !== frule( + bconfig, (NoTangent(), NoTangent(), 21.5), do_thing_2, identity, 32.1 + ) + @test bconfig.forward_calls == [(identity, (32.1,))] + end + end + + @testset "hitting reverse AD" begin + do_thing_3(f, x) = f(x) + function ChainRulesCore.rrule( + config::RuleConfig{>:HasReverseMode}, ::typeof(do_thing_3), f, x + ) + return (NoTangent(), rrule_via_ad(config, f, x)...) + end + + + @testset "$Config" for Config in (MostBoringConfig, MockForwardsConfig) + @test nothing === rrule(Config(), do_thing_3, identity, 32.1) + end + + @testset "$Config" for Config in (MockBothConfig, MockReverseConfig) + bconfig= Config() + @test nothing !== rrule(bconfig, do_thing_3, identity, 32.1) + @test bconfig.reverse_calls == [(identity, (32.1,))] + end + end + + @testset "hitting forwards AD from reverse, if available and reverse if not" begin + # this is is the complicated case doing something interesting and pseudo-mixed mode + do_thing_4(f, x) = f(x) + function ChainRulesCore.rrule( + config::RuleConfig{>:HasForwardsMode}, + ::typeof(do_thing_4), + f::Function, + x::Real, + ) + # real code would support functors/closures, but in interest of keeping example short we exclude it: + @assert (fieldcount(typeof(f)) == 0) "Functors/Closures are not supported" + + ẋ = one(x) + y, ẏ = frule_via_ad(config, (NoTangent(), ẋ), f, x) + pullback_via_forwards_ad(ȳ) = NoTangent(), NoTangent(), ẏ*ȳ + return y, pullback_via_forwards_ad + end + function ChainRulesCore.rrule( + config::RuleConfig{>:Union{HasReverseMode, NoForwardsMode}}, + ::typeof(do_thing_4), + f, + x + ) + y, f_pullback = rrule_via_ad(config, f, x) + do_thing_4_pullback(ȳ) = (NoTangent(), f_pullback(ȳ)...) + return y, do_thing_4_pullback + end + + @test nothing === rrule(MostBoringConfig(), do_thing_4, identity, 32.1) + + @testset "$Config" for Config in (MockBothConfig, MockForwardsConfig) + bconfig= Config() + @test nothing !== rrule(bconfig, do_thing_4, identity, 32.1) + @test bconfig.forward_calls == [(identity, (32.1,))] + end + + rconfig= MockReverseConfig() + @test nothing !== rrule(rconfig, do_thing_4, identity, 32.1) + @test rconfig.reverse_calls == [(identity, (32.1,))] + end + + @testset "RuleConfig broadcasts like a scaler" begin + @test (MostBoringConfig() .=> (1,2,3)) isa NTuple{3, Pair{MostBoringConfig,Int}} + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 612e57f9f..72b2d0acf 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -18,4 +18,5 @@ using Test include("rules.jl") include("rule_definition_tools.jl") + include("config.jl") end