Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: test Zygote-generated gradients using ChainRulesTestUtils test_rrule #987

Closed
wants to merge 13 commits into from
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
AbstractFFTs = "0.5, 1.0"
ChainRules = "0.7.55, 0.8"
ChainRulesCore = "0.9.44, 0.10"
ChainRulesTestUtils = "0.7.1"
DiffRules = "1.0"
FillArrays = "0.8, 0.9, 0.10, 0.11"
ForwardDiff = "0.10"
Expand All @@ -39,6 +40,7 @@ julia = "1.3"

[extras]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Expand All @@ -47,4 +49,4 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["CUDA", "Distances", "FFTW", "FiniteDifferences", "LogExpFunctions", "Test"]
test = ["ChainRulesTestUtils", "CUDA", "Distances", "FFTW", "FiniteDifferences", "LogExpFunctions", "Test"]
2 changes: 2 additions & 0 deletions src/Zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@ using LinearAlgebra: copytri!, AbstractTriangular
import ZygoteRules: @adjoint, @adjoint!, AContext, adjoint, _pullback, pullback,
literal_getproperty, literal_getfield

using ChainRulesCore
using ChainRules: ChainRules, rrule, unthunk, canonicalize
using IRTools
using MacroTools, Requires
using MacroTools: @forward

import Distributed: pmap, CachingPool, workers
export Params, gradient, jacobian, hessian, diaghessian, pullback, pushforward, @code_adjoint
export rrule_via_ad

# Alias matching either a scalar of type `T<:Number` or any `AbstractArray` whose element
# type is a subtype of `T`.
const Numeric{T<:Number} = Union{T, AbstractArray{<:T}}

Expand Down
48 changes: 48 additions & 0 deletions src/compiler/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,51 @@ As per [`chain_rrule`](@ref) but with support for kwargs.
end
return y, kw_zpullback
end

"""
rrule_via_ad(f, args...)

Function with the same API as `ChainRulesCore.rrule`, used for testing Zygote gradients
with `ChainRulesTestUtils.test_rrule`.

```
ChainRulesTestUtils.test_rrule(round, 2.2; rrule_f=rrule_via_ad)
```
"""
function rrule_via_ad(f::Function, args...)
    # `y` is the primal result `f(args...)`; `pb` is the pullback produced by Zygote.
    y, pb = pullback(f, args...)
    # ChainRules-style pullback: returns a tuple whose first entry is `NoTangent()` (for
    # `f` itself) followed by one differential per element of `args`.
    function ad_pullback(Δ)
        # The incoming cotangent is passed through `wrap_chainrules_output` before being
        # handed to Zygote (presumably converting ChainRules differentials to Zygote's
        # own representation -- the helper is defined outside this section).
        grads = pb(wrap_chainrules_output(Δ))
        if grads === nothing
            # Zygote may return a single `nothing` in place of the whole gradient tuple.
            # Splatting the single `NoTangent()` that `zygote2differential` gives for
            # `nothing` would yield only one entry regardless of `length(args)`, so emit
            # one `NoTangent()` per argument explicitly.
            return (NoTangent(), map(_ -> NoTangent(), args)...)
        end
        return (NoTangent(), zygote2differential(grads, args)...)
    end
    return y, ad_pullback
end

"""
zygote2differential(x, primal)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function is essentially a generalisation of wrap_chainrules_input that preserves the information about the primal type.

I am keeping it separate because it might eventually be moved to ZygoteRules when we support ChainRules types internally in Zygote.

As an aside: that PR is nearly complete (only higher-order AD is broken), but it is 6 months old and would need a massive rebase. The hope is that once we can write rules for higher-order functions by calling back into AD, the number of rules in Zygote will be much smaller and that PR will become easier to finish.


Convert input `x` from the Zygote format to the ChainRules differential types, using
`primal` (the value that `x` is the gradient of) to preserve the primal type information.
"""
function zygote2differential(x, primal)
    # Generic fallback: convert a single gradient against its primal via `z2d`.
    return z2d(x, primal)
end

# A gradient that is just `nothing` becomes `NoTangent()`, whatever the primal is.
function zygote2differential(::Nothing, ::Any)
    return NoTangent()
end

# A tuple of gradients paired with a tuple of primals is converted element by element.
function zygote2differential(t::Tuple, primal::Tuple)
    return map(z2d, t, primal)
end

# A tuple of gradients with a non-tuple primal cannot be paired up: warn and hand the
# gradients back unchanged.
function zygote2differential(t::Tuple, primal)
    @warn "primal should be a tuple, not $primal"
    return t
end
# Fallback: a gradient with no more specific method is passed through unchanged.
function z2d(x, ::Any)
    return x
end

# Zygote's `nothing` gradient maps to `NoTangent()`.
function z2d(::Nothing, ::Any)
    return NoTangent()
end

# Arrays of numbers are returned as they are.
function z2d(a::AbstractArray{<:Number}, primal::AbstractArray{T}) where T
    return a
end

# Any other array is converted element-wise against the matching primal elements.
function z2d(a::AbstractArray, primal::AbstractArray{T}) where T
    return broadcast(z2d, a, primal)
end

# Values that already are ChainRules differentials: call `difftype_warn` (defined outside
# this section) and return them unchanged.
function z2d(x::Union{AbstractZero, Tangent}, ::Any)
    difftype_warn(x)
    return x
end

# Tuple gradients: convert each element against the matching primal element, wrap the
# result in a `Tangent` parameterised by the primal's type, and pass it to `canonicalize`.
function z2d(t::Tuple, primal::Tuple)
    converted::Tuple = map(z2d, t, primal)
    return canonicalize(Tangent{typeof(primal), typeof(converted)}(converted))
end

# NamedTuple gradients: build a `Tangent` for `primal` that has an entry for every field
# of the primal's type.
function z2d(t::NamedTuple, primal)
    P = typeof(primal)
    fields = fieldnames(P)
    # Fields absent from `t` are filled in with `nothing` so that the gradient and the
    # primal field values line up name by name.
    grads = NamedTuple{fields}(map(fld -> haskey(t, fld) ? t[fld] : nothing, fields))
    field_primals = NamedTuple{fields}(map(fld -> getfield(primal, fld), fields))
    converted::NamedTuple = map(z2d, grads, field_primals)
    return canonicalize(Tangent{P, typeof(converted)}(converted))
end

40 changes: 39 additions & 1 deletion test/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,6 @@ using Zygote, Test, ChainRules
end
end


@testset "kwarg, with all AbstractZero partials" begin
# while ChainRules always has a partial for every input, Zygote combined them all
# to a single `nothing` if they are all zero-like.
Expand All @@ -212,6 +211,45 @@ using Zygote, Test, ChainRules
@test (nothing,) == Zygote.gradient(x->not_diff_kw_eg(x, 2), 10.4)
@test (nothing,) == Zygote.gradient(x->not_diff_kw_eg(x, 2; kw=2.0), 10.4)
end

@testset "rrule_via_ad" begin
    # Every case passes `rrule_via_ad` to `test_rrule` through the `rrule_f` keyword, so
    # the rule under test is the one built from Zygote's pullback.
    @testset "basic" begin
        test_rrule(round, 2.2; rrule_f=rrule_via_ad)
        test_rrule(vcat, rand(3), rand(4); rrule_f=rrule_via_ad, check_inferred=false)
        test_rrule(getindex, rand(5), 3; rrule_f=rrule_via_ad)
    end

    @testset "struct" begin
        struct Foo
            x
            y
        end
        makefoo(a, b) = Foo(a, b)
        sumfoo(foo) = foo.x + foo.y

        # The primal must be an actual `Foo` instance; there is no variable named `foo`
        # in this scope (it is only the argument name of `sumfoo`).
        test_rrule(sumfoo, Foo(1.2, 2.3); rrule_f=rrule_via_ad, check_inferred=false)
        test_rrule(makefoo, 1.0, 2.0; rrule_f=rrule_via_ad, check_inferred=false)
    end

    @testset "tuples/namedtuples" begin
        my_tuple(a, b, c) = (a+b, b+c)
        my_namedtuple(a, b, c) = (a=a, b=b, c=0.0)

        # Functions returning tuples / named tuples, including a non-numeric element.
        test_rrule(my_tuple, 1., 2., 3.; rrule_f=rrule_via_ad)
        test_rrule(my_namedtuple, 1., 2., 3.; rrule_f=rrule_via_ad)
        test_rrule(my_namedtuple, 1., (2.0, "str"), 3.; rrule_f=rrule_via_ad)

        # Functions taking a tuple / named tuple as input.
        test_rrule(sum, (1.0, 2.0, 3.0); rrule_f=rrule_via_ad)
        test_rrule(sum, (a=1.0, b=2.0); rrule_f=rrule_via_ad, check_inferred=false)
    end

    @testset "arrays" begin
        # `nada` ignores both inputs and returns a constant.
        nada(x, y) = 1.0
        test_rrule(nada, rand(3), rand(2,3); rrule_f=rrule_via_ad)
        test_rrule(+, rand(3), rand(3); rrule_f=rrule_via_ad)
        test_rrule(*, rand(1, 3), rand(3); rrule_f=rrule_via_ad)
    end
end
end

@testset "FastMath support" begin
Expand Down
3 changes: 2 additions & 1 deletion test/lib/array.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using LinearAlgebra

# issue 897
# Each expected entry equals cos(1), the derivative of `sin` at the diagonal entries of
# `Diagonal(ones(2))`.
@test gradient(x -> sum(sin, Diagonal(x)), ones(2)) == ([0.5403023058681398, 0.5403023058681398],)
# Run the same function through `test_rrule` with `rrule_f=rrule_via_ad`, once on a fixed
# input and once on a random one; `check_inferred=false` turns off the inference check.
test_rrule(x->sum(sin, Diagonal(x)), ones(2); rrule_f=rrule_via_ad, check_inferred=false)
test_rrule(x->sum(sin, Diagonal(x)), rand(3); rrule_f=rrule_via_ad, check_inferred=false)
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Zygote, Test
using Zygote: gradient
using CUDA: has_cuda
using ChainRulesTestUtils

if has_cuda()
@testset "CUDA tests" begin
Expand Down