Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add lazy kronecker product for matrix kernels, if Kronecker.jl is loaded #364

Merged
merged 24 commits into from
Sep 23, 2021

Conversation

Crown421
Copy link
Member

Summary
Following #354 this PR adds lazy Kronecker products for the IndependentMOKernel and the IntrinsicCoregionMOKernel via an optional dependency on Kronecker.jl. Also includes comments and thoughts from the previous PR.

Proposed changes

  • Add lazy kronecker_kernelmatrix

What alternatives have you considered?
The name for kronecker_kernelmatrix is perhaps not ideal, maybe an additional mo suffix/prefix is needed. I think conflating this with kernelkronmat would be confusing.

I have also considered making the _mo_output_covariance function apply also for the regular kernelmatrix, but this may cause conflicts with #363 , and would have to be done once both are resolved.

Breaking changes
None.

@st--
Copy link
Member

st-- commented Sep 2, 2021

I think conflating this with kernelkronmat would be confusing.

Having looked at kernelkronmat (and despaired a bit at its lack of documentation), I agree that conflating the two might be confusing (for one, they have quite different signatures and use-cases). kernelkronmat is for when I want to evaluate a (1D) kernel on a multi-dimensional grid. It's actually more limited than it could be, e.g. I could get a kronecker matrix also for any product kernel where each component applies to exactly one of the dimensions...

However, for the MO use-case, the calling is exactly the same as for plain kernelmatrix - you just get a more efficient object back. Would there be any reason not to just have kernelmatrix return a Kronecker object, if Kronecker.jl is loaded? This relates to SebastianAment/CovarianceFunctions.jl#2 as well.

@Crown421
Copy link
Member Author

Crown421 commented Sep 4, 2021

However, for the MO use-case, the calling is exactly the same as for plain kernelmatrix - you just get a more efficient object back. Would there be any reason not to just have kernelmatrix return a Kronecker object, if Kronecker.jl is loaded?

This was my preferred option as well, I just had some issues with getting the function to overwrite correctly/ thought about making it more explicit to the user.

However, I was able to sort out the code issues that I had, and have now made some changes that clean up the code substantially I think.

I think what is missing though is clear documentation about this., so I will add that soon.

Copy link
Member

@willtebbutt willtebbutt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this broadly looks good -- would just like some code moved around I think.

src/mokernels/intrinsiccoregion.jl Outdated Show resolved Hide resolved
Copy link
Member

@willtebbutt willtebbutt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm happy for this to go in now. Just needs a patch bump, then please feel free to merge and tag when CI passes.

src/mokernels/independent.jl Outdated Show resolved Hide resolved
src/mokernels/intrinsiccoregion.jl Outdated Show resolved Hide resolved
@devmotion
Copy link
Member

The test errors seem real: https://github.com/JuliaGaussianProcesses/KernelFunctions.jl/pull/364/checks?check_run_id=3675219571#step:6:123 (and the same with the stable Julia version)

I assume this is caused by the hardcoded Eye{Bool} (before it was Eye{eltype(Kfeatures)}) since booleans are not differentiable in ChainRules (and hence also Zygote).

@Crown421
Copy link
Member Author

The test errors seem real: https://github.com/JuliaGaussianProcesses/KernelFunctions.jl/pull/364/checks?check_run_id=3675219571#step:6:123 (and the same with the stable Julia version)

I assume this is caused by the hardcoded Eye{Bool} (before it was Eye{eltype(Kfeatures)}) since booleans are not differentiable in ChainRules (and hence also Zygote).

This is very odd, because previously the tests passed, and I didn't think that I changed anything that would affect this.

@willtebbutt
Copy link
Member

willtebbutt commented Sep 22, 2021

Looking at the error, I think the optimal fix is probably just to make the pullbacks defined for pairwise of Delta around here accept any type, rather than just AbstractMatrixs.

My reasoning is as follows: based on the CI logs that @devmotion linked, it looks like a ZeroTangent is somehow making its way into the pairwise_pullback, which means that a gradient has (correctly) been dropped somewhere earlier in the reverse pass. Usually I would expect Zygote to pick up on this and to never call the pullback, but it seems like that's not happening here for some reason and we need to handle the ZeroTangent manually.

edit: just tried this locally and can confirm that it works.

@Crown421
Copy link
Member Author

I just had a look at the successful test from a week (or so) ago (https://github.com/JuliaGaussianProcesses/KernelFunctions.jl/runs/3613631221), and while I am not 100% sure I am looking at the right things I see that Zygote has gone from 0.6.21 to 0.6.22 (this happened 13h ago in fact, https://github.com/FluxML/Zygote.jl/releases).
One closed issues includes work on adjoints for specialized matrices, which may have changed things for this package?

@willtebbutt
Copy link
Member

willtebbutt commented Sep 22, 2021

One closed issues includes work on adjoints for specialized matrices, which may have changed things for this package?

Yeah, that seems like a likely culprit, specfically FluxML/Zygote.jl#1044 . It's 100% a bug fix though, so we've done that classic thing whereby KF only works properly because of a bug in Zygote.

Tbh, the rules in question should probably all be declared no-ops from an AD perspective using @non_differentiable, which will also fix the problem (I suspect that they were written before we had @non_differentiable available to us) but widening the set of acceptable cotangnets from AbstractMatrix to Any will solve the problem for now.

@Crown421
Copy link
Member Author

Tbh, the rules in question should probably all be declared no-ops from an AD perspective using @non_differentiable, which will also fix the problem (I suspect that they were written before we had @non_differentiable available to us) but widening the set of acceptable cotangnets from AbstractMatrix to Any will solve the problem for now.

After some bad handling of of git on my part, I have now made some changes, I hope I understood correctly what you suggested.

@willtebbutt
Copy link
Member

Looks like it's nearly there. Just also needs the same kind of modification here if I'm interpretting the CI logs correctly.

@Crown421
Copy link
Member Author

Looks like it's nearly there. Just also needs the same kind of modification here if I'm interpretting the CI logs correctly.

Ok, tests are passing now. I should be ready to merge now I think (I don't have permission to merge).

@willtebbutt willtebbutt merged commit c76b27d into JuliaGaussianProcesses:master Sep 23, 2021
@willtebbutt
Copy link
Member

Great. @Crown421 I've invited you to join the org. Please just make sure to have read the first bit of the ColPrac 🙂 if you've not already.

@DhairyaLGandhi
Copy link

DhairyaLGandhi commented Sep 23, 2021

This doesn't seem great for Zygote flexibility wise. It shouldn't be returning ChainRules types since most rules shouldn't be written to handle them. We should convert these to Zygote friendly types where something like this happens.

@willtebbutt
Copy link
Member

Sorry @DhairyaLGandhi I don't follow. Which bit in particular are you referring to?

@DhairyaLGandhi
Copy link

DhairyaLGandhi commented Sep 23, 2021

Needing to have adjoints be aware of (/explicitly handle) ChainRules's types, I mean.

@willtebbutt
Copy link
Member

Oh, but we're just writing rrules, so we should only have to worry about ChainRules types, no?

@Crown421 Crown421 deleted the mo-lazy branch September 24, 2021 09:23
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants