fix MaternKernel AD, but remove differentiation wrt \nu #425
Codecov Report
@@ Coverage Diff @@
## master #425 +/- ##
==========================================
+ Coverage 93.16% 93.18% +0.01%
==========================================
Files 52 52
Lines 1259 1261 +2
==========================================
+ Hits 1173 1175 +2
Misses 86 86
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
It's unlikely that this will be fixed anytime soon: e.g. for ChainRules, derivatives w.r.t. nu are not implemented in SpecialFunctions, because they are more involved and would require additional dependencies.
The problem is not differentiating through \nu, no? The problem is differentiating through the Bessel function given the distance input.
That's what's tested and commented on in the test though, isn't it?
That's true. That's why I am proposing that we just ban differentiation through \nu and stop testing for it.
Sounds good to me 🙂
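For reference, a sketch of why the two derivatives discussed above differ in difficulty. It uses the standard textbook form of the Matérn kernel with unit length scale; the symbol $z$ is introduced here for illustration and is not taken from the package source.

```latex
% Matérn kernel as a function of the distance d, with order \nu > 0
k_\nu(d) = \frac{2^{1-\nu}}{\Gamma(\nu)} \, z^{\nu} K_\nu(z),
\qquad z = \sqrt{2\nu}\, d

% Derivative of the modified Bessel function of the second kind
% with respect to its argument: closed form in neighbouring orders
\frac{\partial K_\nu(z)}{\partial z}
  = -\tfrac{1}{2}\bigl(K_{\nu-1}(z) + K_{\nu+1}(z)\bigr)
```

The derivative with respect to the distance therefore only needs further Bessel evaluations, whereas the derivative of $K_\nu(z)$ with respect to the order $\nu$ has no comparably simple closed form for general $\nu$, which is consistent with it being left unimplemented upstream.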
…/KernelFunctions.jl into st/test_Matern_AD
Would you want to do that explicitly in code? Or just add a comment to the docstring?
I think adding a comment to the docstring is enough.
Co-authored-by: David Widmann
We need to add a
Might be an upstream issue: it seems Zygote tries to convert a ChainRulesCore.NotImplemented, which by design does not support most operations.
src/basekernels/matern.jl
Outdated
- @inline function kappa(κ::MaternKernel, d::Real)
-     result = _matern(only(κ.ν), d)
+ @inline function kappa(k::MaternKernel, d::Real)
+     nu = ChainRulesCore.@ignore_derivatives only(k.ν) # work-around for Zygote AD
I wonder whether this is quite the right way to handle this.
In particular, what happens if someone does try to compute gradients w.r.t. k.ν? Do they silently get the wrong answer?
I wonder whether it would make more sense to try to return a NotImplemented tangent (see here), which should error if the user ever tries to do anything with the tangent.
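A minimal sketch of what returning a NotImplemented tangent could look like, assuming ChainRulesCore's `@not_implemented` macro and a hand-written `rrule`. The function `toy_kernel` is a hypothetical stand-in, not the package's Matérn implementation, and the pullback is called directly rather than through Zygote.

```julia
using ChainRulesCore

# Hypothetical stand-in for the kernel function; not the package's `_matern`.
toy_kernel(nu::Real, d::Real) = exp(-nu * d)

function ChainRulesCore.rrule(::typeof(toy_kernel), nu::Real, d::Real)
    y = toy_kernel(nu, d)
    function toy_kernel_pullback(dy)
        # Mark the tangent w.r.t. nu as deliberately missing instead of zero.
        dnu = ChainRulesCore.@not_implemented("derivative w.r.t. nu is not implemented")
        # The tangent w.r.t. the distance has a closed form for this toy function.
        dd = -nu * y * dy
        return NoTangent(), dnu, dd
    end
    return y, toy_kernel_pullback
end

# Call the rule by hand to inspect the tangents it returns.
y, pullback = ChainRulesCore.rrule(toy_kernel, 1.5, 0.3)
_, dnu, dd = pullback(1.0)
```

Here `dnu` is a `ChainRulesCore.NotImplemented` object rather than a number, so downstream arithmetic on it is meant to fail loudly instead of producing a silent zero. How a given AD backend handles such a tangent may differ.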
Yes, this change would mean that there would be a zero gradient w.r.t. k.\nu [NB- I find unicode field names so annoying!]. @theogf suggested previously that a note in the docstring would be enough.
I agree that in principle returning an error would be better, but I'm not sure off the top of my head how to actually implement that. Could we leave that for an issue/future PR?
NB- I don't know why the field access confuses Zygote. It's not got anything to do with the gradient not being implemented!?
I'm really not sure that I'm comfortable with a silently-incorrect gradient -- if I encountered this as a user, I would be furious 😂 (I think). I'll take a look at how to do this on my machine now and provide a suggestion.
If you can fix it right away, that also works :)
Personally I'd still prefer a docstring-explained missing gradient wrt one hyperparameter over no gradients working 😄
Haha fair enough. If I can't find a fix, we should merge it.
Co-authored-by: willtebbutt
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
…/KernelFunctions.jl into st/test_Matern_AD
I'm basically happy with this, please feel free to merge once the commented-out code has been un-commented :)
edit: also, patch bump
to keep track of whether it's working again
~Now changed so that it only tests Forward/Reverse diff AD, which are working fine at this point. Not sure why Zygote AD test is failing, or how to fix it :|~
So..... turns out the Zygote issue was the accessing of the kernel object's field. 🤯
Something like this works fine:
So the solution was to simply wrap the k.\nu in an ignore_derivatives ... which is fine because the docstring explicitly rules this out anyway.
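The workaround described above can be sketched as follows. This is an illustration only, assuming ChainRulesCore and Zygote are installed; `ToyKernel` and its `kappa` method are hypothetical stand-ins for the real `MaternKernel`, with a plain exponential in place of the Bessel-based formula.

```julia
using ChainRulesCore
using Zygote

# Hypothetical kernel type that stores its parameter in a one-element vector.
struct ToyKernel{T<:Real}
    ν::Vector{T}
end

function kappa(k::ToyKernel, d::Real)
    # Hide the field access from AD: nu is treated as a constant.
    nu = ChainRulesCore.@ignore_derivatives only(k.ν)
    return exp(-nu * d)
end

k = ToyKernel([1.5])

# Gradient w.r.t. the distance still works as usual.
grad_d = only(Zygote.gradient(d -> kappa(k, d), 0.3))

# Gradient w.r.t. the kernel: no derivative information flows to ν on this path.
grad_k = only(Zygote.gradient(kern -> kappa(kern, 0.3), k))
```

The trade-off raised in the review is visible here: the distance gradient is computed normally, while a request for the gradient with respect to the parameter does not raise an error, so the limitation has to be documented in the docstring.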