Canonicalize the Composite before changing it to a NamedTuple #926
Looks like we either need JuliaDiff/ChainRulesCore.jl#321 or use
maybe_canonicalized = isabstracttype(P) ? x : canonicalize(x)
xp = map(wrap_chainrules_output, maybe_canonicalized)
instead. It seems to be able to do type inference alright. |
Seems a bit hairy. Could you give me an idea of what |
Yeah I'd prefer the issue but there might be reasons not to do that (I can't think of any, but there might be). To understand The zeros for other fields are implicit, meaning it is possible to sum
julia> Composite{MyStruct}(;a=ȳ) + Composite{MyStruct}(;b=ȳ2)
Composite{MyStruct}(;a=ȳ, b=ȳ2)
julia> canonicalize(Composite{MyStruct}(;a=ȳ))
Composite{MyStruct}(;a=ȳ, b=Zero())
This relies on knowing the primal type ( *I think we can get rid of this once Zygote uses ChainRules differential types internally |
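For readers without the relevant ChainRulesCore version at hand, the implicit-zero behaviour described in the comment above can be sketched with toy types. This is only an illustration under stated assumptions: `ToyComposite`, `ToyZero`, and `toy_canonicalize` are hypothetical stand-ins written for this sketch, not the ChainRulesCore implementation, and field values are assumed to be plain numbers.

```julia
# Toy stand-ins for illustration only; these are NOT the ChainRulesCore types.
struct ToyZero end                       # plays the role of Zero()

struct MyStruct                          # primal type with two fields
    a::Float64
    b::Float64
end

struct ToyComposite{P,T<:NamedTuple}     # differential for primal type P
    backing::T
end

# Keyword constructor: only the fields that are passed get stored.
function ToyComposite{P}(; kwargs...) where {P}
    nt = (; kwargs...)
    return ToyComposite{P,typeof(nt)}(nt)
end

# Addition treats every missing field as an implicit zero.
function Base.:+(x::ToyComposite{P}, y::ToyComposite{P}) where {P}
    ks = union(keys(x.backing), keys(y.backing))
    vals = [get(x.backing, k, 0.0) + get(y.backing, k, 0.0) for k in ks]
    nt = NamedTuple{Tuple(ks)}(Tuple(vals))
    return ToyComposite{P,typeof(nt)}(nt)
end

# Canonicalizing needs the primal type P to know the full list of fields,
# and fills every field that is absent with an explicit ToyZero().
function toy_canonicalize(x::ToyComposite{P}) where {P}
    names = fieldnames(P)
    vals = map(n -> get(x.backing, n, ToyZero()), names)
    nt = NamedTuple{names}(vals)
    return ToyComposite{P,typeof(nt)}(nt)
end

s = ToyComposite{MyStruct}(; a = 1.0) + ToyComposite{MyStruct}(; b = 2.0)
c = toy_canonicalize(ToyComposite{MyStruct}(; a = 1.0))
```

Here `s.backing` carries both `a` and `b`, while `c.backing` carries `a` plus an explicit `ToyZero()` for `b`. The lookup of `fieldnames(P)` is the step that depends on knowing the primal type.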
I am pretty sure we should do this. I commented on this here: #861 (comment) |
I've added ChainRulesCore as a direct dependency to be able to specify a compat bound. I don't have merge rights but there is nothing I want to add, so please go ahead once you think it's ready |
The GPU tests for Cholesky seem flaky. Let's see if they work for bors. bors r+ |
Same test failed in a different PR which only added another test. Not sure what to think about that, but I don't think it's related to either of the PRs? |
Merge conflict. |
Could you try again please? |
bors r+ |
Build succeeded: |
Closes #922.
The issue occurs when the following are all true:
In that case, the two gradients that originate from an rrule are Composites with disjoint fields. When they are transformed to a NamedTuple, the Zygote internal representation of derivatives w.r.t. structs, these two NamedTuples have disjoint sets of keys, and are not accumulated correctly.
By adding canonicalize(x), the Composite gets explicit Zero() fields, which means the resulting NamedTuples have the complete set of fields.
- blocked by JuliaDiff/ChainRules.jl#390 (tests only)
- need to change compat once the above is merged
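
The failure mode described above can be illustrated with plain NamedTuples. This is a rough sketch and not Zygote's code: `toy_accum` is a hypothetical accumulation rule written for this example, which walks only the keys of its first argument, and representing the explicit zeros as `nothing` is likewise an assumption of the sketch.

```julia
# Hypothetical accumulation rule, written only for this sketch. It walks the
# keys of `x` and treats a key that is missing from `y` as "no gradient".
toy_accum(x::Number, y::Number) = x + y
toy_accum(x, ::Nothing) = x
toy_accum(::Nothing, y) = y
toy_accum(::Nothing, ::Nothing) = nothing

function toy_accum(x::NamedTuple, y::NamedTuple)
    vals = map(k -> toy_accum(x[k], get(y, k, nothing)), keys(x))
    return NamedTuple{keys(x)}(vals)
end

# Two gradients with disjoint keys: the contribution for `b` is silently lost.
g1 = (a = 1.0,)
g2 = (b = 2.0,)
lost = toy_accum(g1, g2)

# Same gradients with the complete set of fields (zeros written as `nothing`):
# both contributions survive.
g1c = (a = 1.0, b = nothing)
g2c = (a = nothing, b = 2.0)
full = toy_accum(g1c, g2c)
```

With disjoint keys, `lost` has no `b` entry at all. Once both inputs carry every field, `full` contains the contributions for both `a` and `b`.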