accumulation of gradients #922
Comments
@simeonschaub this is your area isn't it? Any thoughts? This is ringing some alarm bells about custom |
Relevant too - #909 |
This is intentional, |
Ah, thanks for explaining. In this case, the I suspect what is happening is that instead of the intended May I ask what the reason is for not defining the accumulating version of
for getfun in (:getfield, :getproperty)
    literal_getfun = Symbol(:literal_, getfun)
    @eval begin
        @adjoint function $literal_getfun(x, ::Val{f}) where f
            val = $getfun(x, f)
            function back(Δ)
                accum_param(__context__, val, Δ) === nothing && return
                if isimmutable(x)
                    ((; nt_nothing(x)..., pair(Val(f), Δ, x)...), nothing)
                else
                    dx = grad_mut(__context__, x)
                    dx[] = (; dx[]..., pair(Val(f), accum($getfun(dx[], f), Δ))...)
                    return (dx, nothing)
                end
            end
            unwrap(val), back
        end
    end
end
We can't do that since then Zygote will ignore custom |
Hmm, could you please elaborate or provide an MWE that would not work with this definition? We should probably have a test which catches this, if I add the above the tests pass alright. |
Hmm, that also works correctly for me as far as I can tell. However, the following example does not work
using Zygote
using ChainRules
struct MyStruct
    a
    b
end
Base.getproperty(ms::MyStruct, s::Symbol) = s === :c ? ms.a + ms.b : getfield(ms, s)
sumall(ms::MyStruct) = ms.a + ms.b + ms.c
ms = MyStruct(1, 2)
julia> Zygote.pullback(sumall, ms)[2](1) # master (correct)
((a = 2, b = 2),)
julia> Zygote.pullback(sumall, ms)[2](1) # adding @adjoint for literal_getproperty (incorrect)
((a = 1, b = 1, c = 1),)
I now understand what the issue is, instead of actually ADing through the custom |
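As a sanity check on which of the two results above is right, the derivative can be estimated without any AD at all. The sketch below copies `MyStruct` and `sumall` from the example and uses a central finite difference; `central_diff` is a hypothetical helper written for this illustration, not part of Zygote or ChainRules.

```julia
# Copy of the example's struct: `c` is a computed property, not a field.
struct MyStruct
    a
    b
end
Base.getproperty(ms::MyStruct, s::Symbol) = s === :c ? ms.a + ms.b : getfield(ms, s)
sumall(ms::MyStruct) = ms.a + ms.b + ms.c

# Hypothetical helper: central finite difference of a scalar function.
central_diff(f, x; h = 1e-6) = (f(x + h) - f(x - h)) / (2h)

# Perturb one field at a time while holding the other fixed.
grad_a = central_diff(a -> sumall(MyStruct(a, 2.0)), 1.0)
grad_b = central_diff(b -> sumall(MyStruct(1.0, b)), 2.0)

println((a = grad_a, b = grad_b))
```

Since `ms.c` expands to `ms.a + ms.b`, each field contributes twice to `sumall`, so both estimates come out close to 2. That agrees with the result labelled correct, where the gradient flows through the custom `getproperty` into the real fields `a` and `b`.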
Is
julia> x
(V = [0.21740818056553368 -0.4395847588546205; 0.2579092264854329 -0.745896461490187],)
julia> y
(U = [0.6219574992864582 -0.0864685332862285; 0.6579918408396744 0.6443443552838996; 0.7339635503235561 -0.5609274817384491],)
julia> Zygote.accum(x, y)
(V = [0.21740818056553368 -0.4395847588546205; 0.2579092264854329 -0.745896461490187],)
julia> Zygote.accum(y, x)
(U = [0.6219574992864582 -0.0864685332862285; 0.6579918408396744 0.6443443552838996; 0.7339635503235561 -0.5609274817384491],)
It seems as if it should include both |
🤦 I encountered this before, I thought we fixed it, but clearly not.
We should definitely be throwing an error here, because |
924: Throw error when trying to accum two NamedTuples with different keys r=DhairyaLGandhi a=mzgubic Would have helped here: #922 (comment) Co-authored-by: Miha Zgubic <[email protected]>
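To make the behaviour described in 924 concrete, here is a minimal sketch in plain Julia of an accumulation that refuses mismatched keys instead of silently dropping one side. `strict_accum` and `accum_field` are hypothetical names for this illustration; this is not Zygote's actual `accum`, and treating `nothing` as "no gradient" is an assumption of the sketch.

```julia
# Hypothetical fieldwise accumulation; `nothing` stands for "no gradient".
accum_field(a, b) = a + b
accum_field(a, ::Nothing) = a
accum_field(::Nothing, b) = b
accum_field(::Nothing, ::Nothing) = nothing

# Hypothetical strict accumulation: both NamedTuples must have the same keys.
function strict_accum(x::NamedTuple{K1}, y::NamedTuple{K2}) where {K1, K2}
    if Set(K1) != Set(K2)
        throw(ArgumentError(
            "cannot accumulate NamedTuples with different keys: $K1 and $K2"
        ))
    end
    # Look fields up by name so that key order does not matter.
    vals = map(k -> accum_field(getfield(x, k), getfield(y, k)), K1)
    return NamedTuple{K1}(vals)
end

# Matching keys accumulate field by field.
println(strict_accum((a = 1, b = nothing), (a = 2, b = 3)))

# Mismatched keys, as with `(V = ...,)` and `(U = ...,)` above, raise an error.
try
    strict_accum((V = 1.0,), (U = 2.0,))
catch err
    println("error: ", err.msg)
end
```

Failing loudly here would have pointed at the `V`/`U` mismatch directly rather than returning whichever argument came first.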
925: add test for literal getproperty overwrite r=DhairyaLGandhi a=mzgubic Another small thing that came from #922 , namely #922 (comment). Writing a separate @adjoint for `literal_getproperty`, which shouldn't be done, passes all tests at the moment. This test will fail if a custom adjoint is written for `literal_getproperty`. @simeonschaub this is from our discussion yesterday Co-authored-by: Miha Zgubic <[email protected]>
Okay, so the last idea I have to fix this is to:
function canonicalize(comp::Composite{P, <:NamedTuple{L}}) where {P<:SVD, L}
    nil = (U = Zero(), S = Zero(), V = Zero())
    combined = merge(nil, ChainRulesCore.backing(comp))
    if length(combined) !== fieldcount(P)
        throw(ArgumentError(
            "Composite fields do not match primal fields.\n" *
            "Composite fields: $L. Primal ($P) fields: $(fieldnames(P))"
        ))
    end
    return Composite{P, typeof(combined)}(combined)
end
At the moment, canonicalizing SVD Composites gives
julia> x = Composite{SVD{Float64,Float64,Array{Float64,2}}}(V = [0.6662838804680036 -0.5008780874375076; 0.729969913633919 0.228331721577141],)
julia> ChainRules.canonicalize(x)
ERROR: ArgumentError: Composite fields do not match primal fields.
Composite fields: (:V,). Primal (SVD{Float64,Float64,Array{Float64,2}}) fields: (:U, :S, :Vt)
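The merge-with-zeros idea in the snippet above can be tried out without ChainRulesCore. The sketch below is a simplified stand-in: `fill_missing_fields` is a hypothetical name, and `nothing` is used in place of `Zero()` purely for illustration. It relies only on `Base.merge`, which keeps the key order of its first argument, takes values from the second where keys overlap, and appends any keys that exist only in the second.

```julia
# Hypothetical helper: pad a partial gradient out to a full set of names.
function fill_missing_fields(partial::NamedTuple, names::Tuple{Vararg{Symbol}})
    # One placeholder per expected name (`nothing` stands in for a zero tangent).
    nil = NamedTuple{names}(map(_ -> nothing, names))
    combined = merge(nil, partial)
    # A key in `partial` that is not in `names` makes `combined` too long.
    if length(combined) != length(names)
        throw(ArgumentError(
            "gradient keys $(keys(partial)) do not match expected names $names"
        ))
    end
    return combined
end

# A gradient for `V` alone is padded out to all three names.
println(fill_missing_fields((V = 2.0,), (:U, :S, :V)))

# If the expected names are (:U, :S, :Vt), the key `V` has nowhere to go.
try
    fill_missing_fields((V = 2.0,), (:U, :S, :Vt))
catch err
    println("error: ", err.msg)
end
```

The second call mirrors the error shown above: the gradient is keyed by the property `V`, while the stored fields are `U`, `S` and `Vt`, so the merged tuple has four entries instead of three.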
I agree that |
925: add test for literal getproperty overwrite r=oxinabox a=mzgubic Another small thing that came from #922 , namely #922 (comment). Writing a separate @adjoint for `literal_getproperty`, which shouldn't be done, passes all tests at the moment. This test will fail if a custom adjoint is written for `literal_getproperty`. @simeonschaub this is from our discussion yesterday Co-authored-by: Miha Zgubic <[email protected]>
The issue occurs when the following are all true:
An example is svd factorisation:
I think the reason is that the accumulation of gradients is not defined for literal_getproperty, just literal_getfield
Zygote.jl/src/lib/lib.jl
Lines 213 to 226 in 890b6f5
Will submit a PR later today