accum NamedTuple sometimes silently drops derivatives #861
Comments
Although, looking at what Zygote returns for a function that reads only one field:

```julia
julia> Zygote.gradient(x -> sin(x.a), (a=4,b=4))
((a = -0.6536436208636119, b = nothing),)
```

(Had the gradient not had a field …
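The output above has the shape that the issue body argues for: a `NamedTuple` with every field of the input, using `nothing` for the field that received no gradient. Below is a minimal sketch of expanding a partial gradient to that shape. It assumes `nothing` is the zero marker, and `canonical_grad` is a hypothetical helper, not a Zygote or ChainRules function.

```julia
# Hypothetical helper, not part of Zygote: expand a partial gradient
# NamedTuple to the full field list of the primal type `T`, filling any
# omitted field with `nothing` (Zygote's marker for "no gradient").
function canonical_grad(::Type{T}, partial::NamedTuple) where {T}
    fields = fieldnames(T)
    # Reject gradient entries for fields the primal type does not have.
    for k in keys(partial)
        k in fields || throw(ArgumentError("$T has no field $k"))
    end
    # One entry per primal field, in the primal's field order.
    vals = map(f -> get(partial, f, nothing), fields)
    return NamedTuple{fields}(vals)
end

# Usage: the primal here is the NamedTuple from the example above.
primal = (a = 4, b = 4)
canonical_grad(typeof(primal), (a = -0.6536436208636119,))
# returns (a = -0.6536436208636119, b = nothing)
```

With every gradient in this form, two gradients for the same primal always have identical field sets, so accumulating them cannot drop a field.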
This is the case, as I understand it. As such … Actually, on that basis: Zygote.jl/src/compiler/chainrules.jl, lines 50 to 53 at commit 4fdb691.
We should fix both problems: this case not erroring, and ChainRules outputs not getting canonicalized.
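For the "not erroring" half, one option is to compare the two field sets before accumulating and throw on a mismatch, so the bug surfaces instead of a derivative silently vanishing. Below is a minimal sketch with hypothetical names (`add_grad`, `accum_strict`). It is not Zygote's `accum` and not necessarily what #924 did.

```julia
# Hypothetical sketch, not Zygote's `accum`: add two gradients,
# treating `nothing` as zero.
add_grad(x, y) = x + y
add_grad(::Nothing, y) = y
add_grad(x, ::Nothing) = x
add_grad(::Nothing, ::Nothing) = nothing
add_grad(x::NamedTuple, y::NamedTuple) = accum_strict(x, y)

# Accumulate two NamedTuple gradients, throwing instead of silently
# dropping a field when their field sets differ.
function accum_strict(x::NamedTuple, y::NamedTuple)
    if Set(keys(x)) != Set(keys(y))
        throw(ArgumentError("gradient fields differ: $(keys(x)) vs $(keys(y))"))
    end
    # Same field set, possibly in a different order: add by name, in x's order.
    vals = map(k -> add_grad(getfield(x, k), getfield(y, k)), keys(x))
    return NamedTuple{keys(x)}(vals)
end
```

A hand-written adjoint that returns only a subset of the fields would then fail loudly at the first accumulation instead of producing a wrong gradient.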
This is exactly what happened: I returned a subset of the required fields from a particular hand-written …
Closed by #924.
This works correctly:
This drops the derivative w.r.t. `b`:

Now this is definitely surprising, and I think it's a bug. The only way it could be intended behaviour is if the derivative w.r.t. a struct is always supposed to be a `NamedTuple` containing all of its fields, since that would preclude this kind of problem.

Note that ChainRules types don't suffer from this issue:
so we can probably nick the `ChainRulesCore` implementation of `+` for `Composite`s and deploy it for `accum` with `NamedTuple`s.
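Below is a minimal sketch of that suggestion. It works on plain `NamedTuple`s rather than ChainRulesCore types: it accumulates over the union of both field sets and treats a missing field as `nothing`, so neither argument order can drop a derivative. The names `acc` and `accum_union` are hypothetical. This is not the ChainRulesCore code for `Composite`, nor Zygote's eventual change.

```julia
# Hypothetical sketch, not Zygote or ChainRulesCore code: gradient addition
# where `nothing` (no gradient) acts as zero.
acc(x, y) = x + y
acc(::Nothing, y) = y
acc(x, ::Nothing) = x
acc(::Nothing, ::Nothing) = nothing
acc(x::NamedTuple, y::NamedTuple) = accum_union(x, y)

# Accumulate two NamedTuple gradients over the UNION of their fields, so a
# field present in only one argument is kept rather than dropped.
function accum_union(x::NamedTuple, y::NamedTuple)
    # x's fields first, then any fields that only y has.
    extra = Tuple(k for k in keys(y) if !haskey(x, k))
    ks = (keys(x)..., extra...)
    vals = map(k -> acc(get(x, k, nothing), get(y, k, nothing)), ks)
    return NamedTuple{ks}(vals)
end

# Usage: both argument orders keep the derivative w.r.t. `b`.
accum_union((a = 1, b = 2), (a = 1,))  # returns (a = 2, b = 2)
accum_union((a = 1,), (a = 1, b = 2))  # returns (a = 2, b = 2)
```

Taking the union means no argument order can lose a derivative. The trade-off is that a gradient with the wrong field set for its primal passes through unnoticed rather than raising an error.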