accumulation of gradients #922

Closed
mzgubic opened this issue Mar 22, 2021 · 12 comments · Fixed by #926

Comments

@mzgubic
Collaborator

mzgubic commented Mar 22, 2021

The issue occurs when the following are all true:

  • We have a struct with two differentiable fields
  • We need to accumulate the gradients with respect to both fields
  • The gradients that need to be accumulated both originate from an rrule

An example is svd factorisation:

using Zygote
using LinearAlgebra
using FiniteDifferences
using Test

# function to test AD against finite differences. Can't wait for https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/114
function test_ad(test_function, Δoutput, inputs...; atol=1e-7, rtol=1e-7)
    # Verify that the forwards-pass produces the correct answer.
    output, pb = Zygote.pullback(test_function, inputs...)
    @test output ≈ test_function(inputs...)

    # Compute the adjoints using AD and FiniteDifferences.
    dW_ad = pb(Δoutput)
    dW_fd = FiniteDifferences.j′vp(central_fdm(5, 1), test_function, Δoutput, inputs...)

    # Compare AD and FiniteDifferences results.
    @testset "$(typeof(test_function)) argument $n" for n in eachindex(inputs)
        @test dW_ad[n] ≈ dW_fd[n] atol=atol rtol=rtol
    end
end

function two_svds(X::StridedMatrix{<:Union{Real, Complex}})
    return svd(X).U * svd(X).V'
end

function one_svd(X::StridedMatrix{<:Union{Real, Complex}})
    F = svd(X)
    return F.U * F.V'
end

Δoutput = randn(3,2)
X = randn(3,2)

test_ad(two_svds, Δoutput, X) # works

test_ad(one_svd, Δoutput, X) # fails

  Expression: ≈(dW_ad[n], dW_fd[n], atol = atol, rtol = rtol)
   Evaluated: [0.10086994903629046 0.017821972969422787; -0.025950275563480844 -0.20012541965315742; 0.1827211721197507 0.16897788019829582] ≈ [-1.5134220464481536 -1.9386355487247564; 0.4546889986237118 0.8662734137317394; 0.9222144395826724 1.4004778288991626] (atol=1.0e-7, rtol=1.0e-7)

I think the reason is that the accumulation of gradients is not defined for literal_getproperty, just literal_getfield

Zygote.jl/src/lib/lib.jl

Lines 213 to 226 in 890b6f5

@adjoint function literal_getfield(x, ::Val{f}) where f
  val = getfield(x, f)
  function back(Δ)
    accum_param(__context__, val, Δ) === nothing && return
    if isimmutable(x)
      ((; nt_nothing(x)..., pair(Val(f), Δ, x)...), nothing)
    else
      dx = grad_mut(__context__, x)
      dx[] = (; dx[]..., pair(Val(f), accum(getfield(dx[], f), Δ))...)
      return (dx,nothing)
    end
  end
  unwrap(val), back
end

Will submit a PR later today

@willtebbutt
Member

willtebbutt commented Mar 22, 2021

@simeonschaub this is your area isn't it? Any thoughts? This is ringing some alarm bells about custom getproperty definitions and such, but I can never remember exactly what the story is.

@DhairyaLGandhi
Member

Relevant too - #909

@simeonschaub
Member

I think the reason is that the accumulation of gradients is not defined for literal_getproperty, just literal_getfield

This is intentional: literal_getproperty is supposed to fall back to recursing into getproperty, which should call getfield, and we forward that pullback to literal_getfield. This currently tends to break inference, but shouldn't give any wrong results. I wonder whether this is because the rule for svd is implemented in ChainRules and Zygote somehow isn't handling the NamedTuple -> Composite conversion and back again correctly, but that's only a hunch.

@mzgubic
Collaborator Author

mzgubic commented Mar 22, 2021

Ah, thanks for explaining. In this case, the getproperty for svd is overloaded and does not call getfield. The rrules are defined for getproperty on the SVD factorisation object as well.

I suspect that the intended literal_getproperty -> getproperty -> getfield -> literal_getfield chain is cut short: since getproperty has an rrule defined, the call never reaches the literal_getfield branch that accumulates the gradients.

May I ask what the reason is for not defining the accumulating version of literal_getfield for literal_getproperty as well?
We could have something like

for getfun in (:getfield, :getproperty)
    literal_getfun = Symbol(:literal_, getfun)
    @eval begin
        @adjoint function $literal_getfun(x, ::Val{f}) where f
          val = $getfun(x, f)
          function back(Δ)
            accum_param(__context__, val, Δ) === nothing && return
            if isimmutable(x)
              ((; nt_nothing(x)..., pair(Val(f), Δ, x)...), nothing)
            else
              dx = grad_mut(__context__, x)
              dx[] = (; dx[]..., pair(Val(f), accum($getfun(dx[], f), Δ))...)
              return (dx,nothing)
            end
          end
          unwrap(val), back
        end
    end
end

@simeonschaub
Member

We can't do that since then Zygote will ignore custom getproperty implementations. This is actually what Zygote used to do, so it would just silently return wrong gradients in those cases.

@mzgubic
Collaborator Author

mzgubic commented Mar 22, 2021

Hmm, could you please elaborate or provide an MWE that would not work with this definition? We should probably have a test which catches this; if I add the above, the tests pass alright.

@simeonschaub
Member

Ah, the test case I added in #848 might not have been enough to catch this particular problem. Try your test case in #851, but with y, pb = Zygote._pullback(() -> f.x) instead.

@mzgubic
Collaborator Author

mzgubic commented Mar 22, 2021

Hmm, that also works correctly for me as far as I can tell. However, the following example does not work

using Zygote
using ChainRules

struct MyStruct
       a
       b
end

Base.getproperty(ms::MyStruct, s::Symbol) = s === :c ? ms.a + ms.b : getfield(ms, s)

sumall(ms::MyStruct) = ms.a + ms.b + ms.c

ms = MyStruct(1, 2)

julia> Zygote.pullback(sumall, ms)[2](1) # master (correct)
((a = 2, b = 2),)

julia> Zygote.pullback(sumall, ms)[2](1) # adding @adjoint for literal_getproperty (incorrect)
((a = 1, b = 1, c = 1),)

I now understand what the issue is: instead of actually ADing through the custom getproperty, this patch adds an incorrect adjoint for the custom getproperty.
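
As a sanity check on why (a = 2, b = 2) is the right answer: ms.c is derived from a and b, so sumall is effectively 2a + 2b. The sketch below confirms this with a hand-rolled central difference instead of Zygote. Note that central_diff is a hypothetical helper introduced only for this illustration, not part of any package.

```julia
struct MyStruct
    a
    b
end

# Same overload as in the example above: `c` is a derived property, not a stored field.
Base.getproperty(ms::MyStruct, s::Symbol) = s === :c ? ms.a + ms.b : getfield(ms, s)

sumall(ms::MyStruct) = ms.a + ms.b + ms.c

# Hypothetical helper (assumption): central finite difference of a scalar function.
central_diff(f, x; h=1e-6) = (f(x + h) - f(x - h)) / (2h)

# Partial derivatives of sumall with respect to each stored field.
da = central_diff(a -> sumall(MyStruct(a, 2.0)), 1.0)
db = central_diff(b -> sumall(MyStruct(1.0, b)), 2.0)

@show da db
```

Both partial derivatives should come out close to 2, and there is no separate entry for c because c is not a stored field.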

@mzgubic
Collaborator Author

mzgubic commented Mar 22, 2021

Is accum(x::NamedTuple, y::NamedTuple) working as intended?

julia> x
(V = [0.21740818056553368 -0.4395847588546205; 0.2579092264854329 -0.745896461490187],)

julia> y
(U = [0.6219574992864582 -0.0864685332862285; 0.6579918408396744 0.6443443552838996; 0.7339635503235561 -0.5609274817384491],)

julia> Zygote.accum(x, y)
(V = [0.21740818056553368 -0.4395847588546205; 0.2579092264854329 -0.745896461490187],)

julia> Zygote.accum(y, x)
(U = [0.6219574992864582 -0.0864685332862285; 0.6579918408396744 0.6443443552838996; 0.7339635503235561 -0.5609274817384491],)

It seems as if it should include both U and V in the final answer? As far as I understand accum it is meant to generalise +, i.e. it should be commutative?

@willtebbutt
Member

willtebbutt commented Mar 22, 2021

🤦 I encountered this before, I thought we fixed it, but clearly not.

accum behaves slightly differently from +(::Composite, ::Composite). Zygote assumes that all relevant fields will be present in a given NamedTuple, and so doesn't bother to check whether or not it is in fact there, and I think just looks at the names present in the first argument to accum, hence the behaviour you're seeing.

We should definitely be throwing an error here, because Zygote's assumptions are violated.
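
For contrast, here is a minimal sketch of what a commutative, union-of-keys accumulation would look like. This is not Zygote's accum and not the fix adopted for this issue; accum_union is a hypothetical name used only to illustrate the behaviour asked about above, under the assumption that a missing key means "no gradient".

```julia
# Hypothetical helper (assumption), not Zygote's `accum`: combine two NamedTuple
# gradients over the union of their keys, treating a missing key as "no gradient".
function accum_union(x::NamedTuple, y::NamedTuple)
    ks = Tuple(union(keys(x), keys(y)))
    vals = map(ks) do k
        if haskey(x, k) && haskey(y, k)
            x[k] .+ y[k]   # both sides contribute: add them
        elseif haskey(x, k)
            x[k]           # only the first argument has this key
        else
            y[k]           # only the second argument has this key
        end
    end
    return NamedTuple{ks}(vals)
end

x = (V = [1.0 2.0; 3.0 4.0],)
y = (U = [0.5 0.5; 0.5 0.5; 0.5 0.5],)

xy = accum_union(x, y)
yx = accum_union(y, x)
```

With this definition both xy and yx carry U and V, and they differ only in the order of their keys.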

bors bot added a commit that referenced this issue Mar 23, 2021
924: Throw error when trying to accum two NamedTuples with different keys r=DhairyaLGandhi a=mzgubic

Would have helped here:
#922 (comment)



Co-authored-by: Miha Zgubic
bors bot added a commit that referenced this issue Mar 23, 2021
925: add test for literal getproperty overwrite r=DhairyaLGandhi a=mzgubic

Another small thing that came from #922 , namely #922 (comment).

Writing a separate @adjoint for `literal_getproperty`, which shouldn't be done, passes all tests at the moment. This test will fail if a custom adjoint is written for `literal_getproperty`.

@simeonschaub this is from our discussion yesterday

Co-authored-by: Miha Zgubic
@mzgubic
Collaborator Author

mzgubic commented Mar 23, 2021

Okay, so the last idea I have to fix this is to:

  1. canonicalize the x here

    @eval @inline function wrap_chainrules_output(x::ChainRules.Composite{P, T}) where {P, T<:$T_outer}
      xp = map(wrap_chainrules_output, x)
      convert($T_outer, xp)
    end

  2. make sure canonicalize works for SVD (which it doesn't atm), by doing something like

function canonicalize(comp::Composite{P, <:NamedTuple{L}}) where {P<:SVD, L}
    nil = (U = Zero(), S = Zero(), V = Zero())
    combined = merge(nil, ChainRulesCore.backing(comp))
    if length(combined) !== fieldcount(P)
        throw(ArgumentError(
            "Composite fields do not match primal fields.\n" *
            "Composite fields: $L. Primal ($P) fields: $(fieldnames(P))"
        ))
    end
    return Composite{P, typeof(combined)}(combined)
end

At the moment, canonicalizing SVD Composites gives

julia> x = Composite{SVD{Float64,Float64,Array{Float64,2}}}(V = [0.6662838804680036 -0.5008780874375076; 0.729969913633919 0.228331721577141],)

julia> ChainRules.canonicalize(x)
ERROR: ArgumentError: Composite fields do not match primal fields.
Composite fields: (:V,). Primal (SVD{Float64,Float64,Array{Float64,2}}) fields: (:U, :S, :Vt)
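
The mismatch in that error comes from the difference between the stored fields of an SVD and the properties it exposes. A small sketch using only the LinearAlgebra standard library (the input matrix is arbitrary, chosen just for illustration):

```julia
using LinearAlgebra

F = svd([1.0 2.0; 3.0 4.0; 5.0 6.0])

# Stored fields of the struct: these are what the error message lists as primal fields.
@show fieldnames(typeof(F))

# `V` is not a stored field; `getproperty` derives it from the stored `Vt`.
@show :V in fieldnames(typeof(F))
@show :V in propertynames(F)
@show F.V == F.Vt'
```

A Composite that only carries V therefore has no matching primal field, which is why storing the gradient under Vt instead is attractive.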

@willtebbutt
Member

willtebbutt commented Mar 24, 2021

I agree that canonicalizing where you propose makes sense. I also agree with what you propose in JuliaDiff/ChainRules.jl#389, that improving the SVD adjoints to store Vt rather than V makes a lot of sense.

bors bot added a commit that referenced this issue Mar 25, 2021
925: add test for literal getproperty overwrite r=oxinabox a=mzgubic

Another small thing that came from #922 , namely #922 (comment).

Writing a separate @adjoint for `literal_getproperty`, which shouldn't be done, passes all tests at the moment. This test will fail if a custom adjoint is written for `literal_getproperty`.

@simeonschaub this is from our discussion yesterday

Co-authored-by: Miha Zgubic
@bors bors bot closed this as completed in 56f4118 Mar 25, 2021