-
-
Notifications
You must be signed in to change notification settings - Fork 212
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
WIP: test Zygote-generated gradients using ChainRulesTestUtils test_rrule
#987
Conversation
end | ||
|
||
""" | ||
zygote2differential(x) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function is essentially a generalisation of wrap_chainrules_input
that preserves the information about the primal type.
I am keeping it separate because it might eventually be moved to ZygoteRules when we support ChainRules types internally in Zygote.
As an aside: That PR is nearly complete (only higher order AD is broken) but 6 months old and would need a massive rebase. The hope is that after we have the possibility of rules for higher order functions by calling back into AD, the number of rules in Zygote would be much smaller and the PR would become easier.
superseeded by #990 |
A real
rrule_f
example to go with JuliaDiff/ChainRulesTestUtils.jl#166.To do:
rrule_via_ad
rrule_via_ad
is used to test a gradient that is difficult to evaluate without finite differencingSome benchmarking shows that
test_rrule
is about 40 times slower thangradtest
. Do we understand why?