
fix MaternKernel AD, but remove differentiation wrt \nu #425

Merged 21 commits into master from st/test_Matern_AD on Apr 13, 2022

Conversation

@st-- st-- (Member) commented Jan 13, 2022

To keep track of whether it's working again.

~Now changed so that it only tests ForwardDiff/ReverseDiff AD, which are working fine at this point.

Not sure why the Zygote AD test is failing, or how to fix it :|~

So..... it turns out the Zygote issue was caused by accessing the kernel object's field. 🤯

Something like this works fine:

```julia
using KernelFunctions, Zygote

@inline function matern_kappa_hardcoded_nu(d::Real)
    result = KernelFunctions._matern(1.234, d)
    return ifelse(iszero(d), one(result), result)
end

Zygote.gradient(matern_kappa_hardcoded_nu, 2.345)
```

So the solution was to simply wrap `k.ν` in `@ignore_derivatives`... which is fine because the docstring explicitly rules this out anyway.
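For reference, a minimal sketch of that workaround outside the package (the function name `matern_kappa_ignoring_nu` is made up for illustration; the real change is in `kappa(::MaternKernel, ::Real)`):

```julia
using KernelFunctions, Zygote, ChainRulesCore

# Sketch only: hide the ν field access from Zygote via @ignore_derivatives,
# so that only the distance argument d is differentiated.
function matern_kappa_ignoring_nu(k::MaternKernel, d::Real)
    nu = ChainRulesCore.@ignore_derivatives only(k.ν)
    result = KernelFunctions._matern(nu, d)
    return ifelse(iszero(d), one(result), result)
end

k = MaternKernel(; ν=1.234)
Zygote.gradient(d -> matern_kappa_ignoring_nu(k, d), 2.345)
```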

@codecov codecov bot commented Jan 13, 2022

Codecov Report

Merging #425 (8aee1a4) into master (873aa8d) will increase coverage by 0.01%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##           master     #425      +/-   ##
==========================================
+ Coverage   93.16%   93.18%   +0.01%     
==========================================
  Files          52       52              
  Lines        1259     1261       +2     
==========================================
+ Hits         1173     1175       +2     
  Misses         86       86              
| Impacted Files | Coverage Δ |
|---|---|
| src/matrix/kernelkroneckermat.jl | 100.00% <ø> (ø) |
| src/basekernels/matern.jl | 100.00% <100.00%> (ø) |

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@devmotion devmotion (Member) commented

It's unlikely that this will be fixed anytime soon, since e.g. for ChainRules the derivatives with respect to \nu are not implemented in SpecialFunctions, because they are more involved and would require additional dependencies.

@theogf theogf (Member) commented Jan 13, 2022

The problem is not differentiating through \nu, is it? The problem is differentiating through the Bessel function with respect to the distance input.
It's quite clear that the derivatives with respect to \nu are too ugly and complicated to deal with, and we should make it clear that the kernel is not differentiable with respect to \nu.

@devmotion devmotion (Member) commented

That's what's tested and commented on in the test though, isn't it?

@theogf theogf (Member) commented Jan 13, 2022

That's true. That's why I am proposing that we just ban differentiation through \nu and stop testing for it.

@devmotion devmotion (Member) commented

Sounds good to me 🙂

@st-- st-- (Member, Author) commented Jan 13, 2022

> that we just ban differentiation through \nu

Would you want to do that explicitly in code? Or just add a comment to the docstring?

@st-- st-- changed the title test_ADs of MaternKernel again test_ADs of MaternKernel again (but remove differentiation wrt \nu) Jan 13, 2022
@theogf theogf (Member) commented Jan 13, 2022

> Would you want to do that explicitly in code? Or just add a comment to the docstring?

I think adding a comment to the docstring is enough.

@st-- st-- requested review from theogf and devmotion January 14, 2022 10:33
@st-- st-- (Member, Author) commented Jan 19, 2022

We need to add a `@non_differentiable` somewhere for Zygote...
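(For context: `@non_differentiable` is the ChainRulesCore macro that marks a function call as having no derivatives at all. A minimal sketch of its use, on a hypothetical helper `_get_nu` that is not part of this PR:)

```julia
using ChainRulesCore

# Hypothetical helper extracting ν from its container; not code from this PR.
_get_nu(ν) = only(ν)

# Marking the call non-differentiable makes ChainRules-aware ADs such as Zygote
# treat it as a constant instead of attempting to differentiate through it.
ChainRulesCore.@non_differentiable _get_nu(ν)
```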

@devmotion devmotion (Member) commented

Might be an upstream issue; it seems Zygote tries to convert `ChainRulesCore.NotImplemented`, which by design does not support most operations.

@st-- st-- requested a review from willtebbutt April 13, 2022 06:39
@st-- st-- changed the title test_ADs of MaternKernel again (but remove differentiation wrt \nu) fix MaternKernel AD, but remove differentiation wrt \nu Apr 13, 2022
In src/basekernels/matern.jl:

```diff
-@inline function kappa(κ::MaternKernel, d::Real)
-    result = _matern(only(κ.ν), d)
+@inline function kappa(k::MaternKernel, d::Real)
+    nu = ChainRulesCore.@ignore_derivatives only(k.ν)  # work-around for Zygote AD
```
@willtebbutt willtebbutt (Member) commented Apr 13, 2022


I wonder whether this is quite the right way to handle this.
In particular, what happens if someone does try to compute gradients w.r.t. k.ν? Do they silently get the wrong answer?

I wonder whether it would make more sense to try and return a NotImplemented tangent, see here, which should error if the user ever tries to do anything with the tangent.
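(For the record, a minimal sketch of that suggestion on a stand-in function; `f` and its pullback are purely illustrative and not this PR's code. The tangent for ν is returned as `@not_implemented`, so it errors only when someone actually tries to use it:)

```julia
using ChainRulesCore

f(ν, d) = ν * d  # stand-in for the real Matérn expression

function ChainRulesCore.rrule(::typeof(f), ν, d)
    y = f(ν, d)
    function f_pullback(ȳ)
        # Any attempt to *use* this tangent (e.g. add it to something) throws,
        # instead of silently propagating a zero gradient for ν.
        ν̄ = ChainRulesCore.@not_implemented("derivative w.r.t. ν is not supported")
        d̄ = ȳ * ν  # ordinary tangent for the stand-in's d argument
        return ChainRulesCore.NoTangent(), ν̄, d̄
    end
    return y, f_pullback
end
```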

@st-- st-- (Member, Author) commented Apr 13, 2022


Yes, this change means there would be a zero gradient w.r.t. `k.ν` [NB: I find Unicode field names so annoying!]. @theogf suggested previously that a note in the docstring would be enough.

I agree that in principle returning an error would be better, but I'm not sure off the top of my head how to actually implement that. Could we leave that for an issue/future PR?

@st-- st-- (Member, Author) commented


NB: I don't know why the field access confuses Zygote. It doesn't have anything to do with the gradient not being implemented!?

@willtebbutt willtebbutt (Member) commented Apr 13, 2022


I'm really not sure that I'm comfortable with a silently incorrect gradient; if I encountered this as a user, I would be furious 😂 (I think). I'll take a look now at how to do this on my machine and provide a suggestion.

@st-- st-- (Member, Author) commented


If you can fix it right away, that also works :)

@st-- st-- (Member, Author) commented


Personally, I'd still prefer a missing gradient w.r.t. one hyperparameter (explained in the docstring) over no gradients working at all 😄

@willtebbutt willtebbutt (Member) commented


Haha fair enough. If I can't find a fix, we should merge it.

src/basekernels/matern.jl (review thread, outdated, resolved)
src/basekernels/matern.jl (review thread, outdated, resolved)
st-- and others added 3 commits April 13, 2022 15:01
@st-- st-- requested a review from willtebbutt April 13, 2022 12:05
src/zygoterules.jl (review thread, outdated, resolved)
@willtebbutt willtebbutt mentioned this pull request Apr 13, 2022
@willtebbutt willtebbutt (Member) left a comment


I'm basically happy with this; please feel free to merge once the commented-out code has been uncommented :)

Edit: also, bump the patch version.

@st-- st-- merged commit 8e805ef into master Apr 13, 2022
@st-- st-- deleted the st/test_Matern_AD branch April 13, 2022 12:42
@devmotion devmotion mentioned this pull request Apr 13, 2022