-
Notifications
You must be signed in to change notification settings - Fork 89
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix Composite of SVD #390
Fix Composite of SVD #390
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! Just minor comments. Feel free to merge when ready.
I've modernised the tests from the manual finite differences tests to the new |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. It's nice that it uses test_rrule
now as well. Feel free to merge if all non-nightly tests pass.
Codecov Report
@@ Coverage Diff @@
## master #390 +/- ##
===========================================
- Coverage 98.45% 87.76% -10.70%
===========================================
Files 20 20
Lines 1745 1291 -454
===========================================
- Hits 1718 1133 -585
- Misses 27 158 +131
Continue to review full report at Codecov.
|
926: Canonicalize the Composite before changing it to a NamedTuple r=oxinabox a=mzgubic Closes #922. The issue occurs when the following are all true: 1) We have a struct with two differentiable fields 2) We need to accumulate the gradients with respect to both fields 3) The gradients that need to be accumulated both originate from an rrule In that case, the two gradients that originate from an rrule are `Composite`s with disjoint fields. When they are transformed to a `NamedTuple`, the Zygote internal representation of derivatives w.r.t. structs, these two `NamedTuple`s have disjoint sets of keys, and are not `accum`ulated correctly. By adding `canonicalize(x)`, the `Composite` gets explicit `Zero()` fields, which means the resulting `NamedTuple`s have the complete set of fields. ~~- blocked by JuliaDiff/ChainRules.jl#390 (tests only)~~ ~~- need to change compat once the above is merged~~ Co-authored-by: Miha Zgubic <[email protected]> Co-authored-by: Miha Zgubic <[email protected]>
Closes #389 , closes #106