Skip to content

Commit

Permalink
Merge #926
Browse files Browse the repository at this point in the history
926: Canonicalize the Composite before changing it to a NamedTuple r=oxinabox a=mzgubic

Closes #922.

The issue occurs when the following are all true:
1) We have a struct with two differentiable fields
2) We need to accumulate the gradients with respect to both fields
3) The gradients that need to be accumulated both originate from an rrule

In that case, the two gradients that originate from an rrule are `Composite`s with disjoint fields. When they are transformed to a `NamedTuple`, the Zygote internal representation of derivatives w.r.t. structs, these two `NamedTuple`s have disjoint sets of keys, and are not `accum`ulated correctly.

By adding `canonicalize(x)`, the `Composite` gets explicit `Zero()` fields, which means the resulting `NamedTuple`s have the complete set of fields. 

~~- blocked by JuliaDiff/ChainRules.jl#390 (tests only)~~
~~- need to change compat once the above is merged~~

Co-authored-by: Miha Zgubic <[email protected]>
Co-authored-by: Miha Zgubic <[email protected]>
  • Loading branch information
3 people authored Mar 25, 2021
2 parents 851a649 + a2c4381 commit 56f4118
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 4 deletions.
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.6.5"
version = "0.6.6"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Expand All @@ -22,7 +23,8 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
AbstractFFTs = "0.5, 1.0"
ChainRules = "0.7.49"
ChainRules = "0.7.55"
ChainRulesCore = "0.9.32"
DiffRules = "1.0"
FillArrays = "0.8, 0.9, 0.10, 0.11"
ForwardDiff = "0.10"
Expand Down
2 changes: 1 addition & 1 deletion src/Zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using LinearAlgebra: copytri!, AbstractTriangular
import ZygoteRules: @adjoint, @adjoint!, AContext, adjoint, _pullback, pullback,
literal_getproperty, literal_getfield

using ChainRules: ChainRules, rrule, unthunk
using ChainRules: ChainRules, rrule, unthunk, canonicalize
using IRTools
using MacroTools, Requires
using MacroTools: @forward
Expand Down
2 changes: 1 addition & 1 deletion src/compiler/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ for T_outer in (:Tuple, :NamedTuple)
# branch that changes output type, because nested AD on that kinda thing makes Zygote less
# than happy.
@eval @inline function wrap_chainrules_output(x::ChainRules.Composite{P, T}) where {P, T<:$T_outer}
xp = map(wrap_chainrules_output, x)
xp = map(wrap_chainrules_output, canonicalize(x))
convert($T_outer, xp)
end
end
Expand Down
20 changes: 20 additions & 0 deletions test/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,26 @@ end
@test pb(1) == (nothing, (x = 1, y = nothing), nothing)
end

@testset "issue #922" begin
# checks whether getproperty gets accumulated correctly
# instead of defining a test function as in the issue, compare the two pullbacks
function two_svds(X::StridedMatrix{<:Union{Real, Complex}})
return svd(X).U * svd(X).V'
end

function one_svd(X::StridedMatrix{<:Union{Real, Complex}})
F = svd(X)
return F.U * F.V'
end

Δoutput = randn(3,2)
X = randn(3,2)

d_two = Zygote.pullback(two_svds, X)[2](Δoutput)
d_one = Zygote.pullback(one_svd, X)[2](Δoutput)
@test d_one == d_two
end

# this test fails if adjoint for literal_getproperty is added
# https://github.com/FluxML/Zygote.jl/issues/922#issuecomment-804128905
@testset "overloaded getproperty" begin
Expand Down

3 comments on commit 56f4118

@oxinabox
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Error while trying to register: "Tag with name v0.6.6 already exists and points to a different commit"

@DhairyaLGandhi
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have 0.6.6 which already has this correct? Do we need another tag?

Please sign in to comment.