port rule definitions to ChainRulesCore #242

Merged
merged 16 commits into from
Dec 27, 2020

simeonschaub
Member

The current solution of defining these rules in Zygote is quite unsatisfying and makes working on them unnecessarily complicated. It's also a bit weird that Zygote depends on NNlib at all. Since rules should now be defined using ChainRulesCore anyway, moving the definitions into NNlib itself makes the most sense to me. This will require careful coordination with a new Zygote release that removes these rules, because tests will now fail due to circular dependencies. The best solution might be to add upper bounds on NNlib for all Zygote versions in General retroactively, release a new Zygote version with no dependency on NNlib and then release this. Alternatively, we could tag a breaking release for NNlib, but that seems a bit wrong to me, since these changes shouldn't actually be breaking for anything except Zygote.jl.

Rule definitions are ported from https://github.com/FluxML/Zygote.jl/blob/master/src/lib/nnlib.jl. The tests are mostly just copied from https://github.com/FluxML/Zygote.jl/blob/master/test/gradcheck.jl. Eventually we probably want to port these to ChainRulesTestUtils.jl, which should be quite a bit more accurate and thorough.

@simeonschaub simeonschaub marked this pull request as draft November 16, 2020 20:14
@DhairyaLGandhi
Member

It makes sense to keep NNlib fairly orthogonal and not have it depend on specific AD tooling. That's just the first thought in my head for this.

Maybe we could have the backwards functions defined here and have Zygote hold the "glue" adjoint/rrule defs?

The release work should be fine either way.

@simeonschaub
Member Author

simeonschaub commented Nov 17, 2020

> It makes sense to keep NNlib fairly orthogonal and not have it depend on specific AD tooling. That's just the first thought in my head for this.
>
> Maybe we could have the backwards functions defined here and have zygote have the "glue" adjoint/ rrule defs?

The idea behind ChainRulesCore.jl is basically that it's orthogonal to any specific AD implementation; it just allows packages to define rules in a generic way that a specific AD tool can use. I would argue that ChainRulesCore.rrule is already that "glue" code. It's a fairly minimal package, so it shouldn't really have any side effects, other than that Zygote doesn't have to depend on NNlib anymore. I really don't want Zygote to have to know anything about NNlib, because that makes trying out changes to NNlib a bit cumbersome.
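
To make the "glue" point concrete, here is a minimal sketch of a package defining its own reverse rule through ChainRulesCore alone. `toy_relu` is a hypothetical stand-in and is not the NNlib code from this PR. The sketch assumes ChainRulesCore 1.x, where `NoTangent()` marks an argument without a gradient; releases from the time of this thread used a different name for that marker.

```julia
using ChainRulesCore

# Hypothetical stand-in for a function owned by a library such as NNlib.
toy_relu(x::AbstractArray) = max.(x, zero(eltype(x)))

# The library defines the reverse rule itself and depends only on ChainRulesCore,
# not on any particular AD package.
function ChainRulesCore.rrule(::typeof(toy_relu), x::AbstractArray)
    y = toy_relu(x)
    function toy_relu_pullback(Δ)
        # No tangent for the function object itself; the incoming gradient
        # passes through only where the input was positive.
        return NoTangent(), unthunk(Δ) .* (x .> 0)
    end
    return y, toy_relu_pullback
end

# Any AD tool that understands ChainRulesCore can call the rule like this.
y, back = ChainRulesCore.rrule(toy_relu, [-1.0, 2.0])
_, dx = back([1.0, 1.0])
```

An AD package that consumes ChainRulesCore rules needs no knowledge of `toy_relu` to differentiate through it, which is the decoupling argued for above.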

@CarloLucibello
Member

CarloLucibello commented Nov 19, 2020

> The best solution might be to add upper bounds on NNlib for all Zygote versions in General retroactively, release a new Zygote version with no dependency on NNlib and then release this. Alternatively, we could tag a breaking release for NNlib, but that seems a bit wrong to me, since these changes shouldn't actually be breaking for anything except Zygote.jl.

I would go with the easier solution and just tag breaking releases. I'm sure no one will come after us :)
That said, I'm completely in favor of this change: Zygote should be independent from NNlib, and NNlib's diff rules should stay in NNlib.

@simeonschaub
Member Author

Ok, as @mcabbott pointed out in FluxML/Zygote.jl#824 (comment), I think it would make sense to use ZygoteRules instead of ChainRulesCore just for the broadcasted rules, because they are pretty specific to Zygote and don't necessarily apply to all AD tools. Other than that, would someone be willing to give this a review?

@simeonschaub simeonschaub marked this pull request as ready for review November 21, 2020 19:03
@simeonschaub
Member Author

simeonschaub commented Nov 21, 2020

Ok, this should now be ready from my side. (Probably better to merge the Zygote PR first though, so we can rerun CI here)

@simeonschaub simeonschaub changed the title RFC: port rule definitions to ChainRulesCore port rule definitions to ChainRulesCore Nov 21, 2020
bors bot added a commit to FluxML/Zygote.jl that referenced this pull request Nov 27, 2020
824: move NNlib rules out of Zygote r=CarloLucibello a=simeonschaub

Partner to FluxML/NNlib.jl#242. I talked a bit about how to go about releasing this there, would appreciate any feedback and suggestions.

Co-authored-by: Simeon Schaub <[email protected]>
@CarloLucibello
Member

How do we solve compat issues? Should we bump the NNlib version later?

@simeonschaub
Member Author

simeonschaub commented Dec 8, 2020

My thinking was that we first tag Zygote 0.6 and then we merge and release this. Packages like Flux that depend on both Zygote and NNlib should then bump their compat for both simultaneously.

@CarloLucibello
Member

Zygote 0.6 is being tagged

@CarloLucibello CarloLucibello added this to the v0.8 milestone Dec 21, 2020
@DhairyaLGandhi
Member

DhairyaLGandhi commented Dec 21, 2020

We definitely shouldn't tag without the due diligence, I think.

If CI is a consideration, we can use a manifest file rather than necessarily tagging.

@CarloLucibello
Member

Weird, tests in the multi-threaded environment are not passing

@simeonschaub
Member Author

Yes, I just checked, this reproduces for me locally as well when using two threads.

@simeonschaub
Member Author

It looks like there is a race condition somewhere in the pooling code. Who would be best to take a look at that? I am not really familiar with what it is doing.

@simeonschaub
Member Author

Ok, this doesn't really seem related to threading; I think I misdiagnosed that. I don't think the differences here can be explained by numerical instabilities though, as they are often quite large, but I am a bit unsure how to debug this further.

@simeonschaub
Member Author

Ok, there's something really sketchy going on with maxpool. I am just wondering why this wasn't caught by Zygote before. I would be really glad if someone who initially implemented this could chime in here. I can't imagine that the gradient would be that unstable, or is there some discontinuity at play here and I shouldn't be using a central difference method? I don't really see why this would be affected by threading though.
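
The discontinuity question can be illustrated with a small sketch. It uses a hypothetical one-window pool (`pool`) and a hypothetical analytic rule (`pool_grad`) that sends the whole gradient to the first maximal entry; neither is the NNlib implementation. Whether ties explain the failures seen in this PR was not established in the thread. The sketch only shows that a central difference and such a rule can legitimately disagree at a tie.

```julia
# Central finite difference of f at x along coordinate i with step h.
function central_diff(f, x::Vector{Float64}, i::Int; h::Float64 = 1e-6)
    xp = copy(x); xp[i] += h
    xm = copy(x); xm[i] -= h
    return (f(xp) - f(xm)) / (2h)
end

# Hypothetical one-window "max pool": the maximum over all entries.
pool(x) = maximum(x)

# Hypothetical analytic rule: route the whole gradient to the first maximal entry.
function pool_grad(x)
    g = zeros(length(x))
    g[argmax(x)] = 1.0
    return g
end

# Two entries tie for the maximum, so `pool` has a kink at this point.
x_tie = [1.0, 1.0, 0.0]
fd = [central_diff(pool, x_tie, i) for i in eachindex(x_tie)]
an = pool_grad(x_tie)

# Away from ties the two estimates agree.
x_ok = [1.0, 3.0, 2.0]
fd_ok = [central_diff(pool, x_ok, i) for i in eachindex(x_ok)]
an_ok = pool_grad(x_ok)
```

At the tie, stepping either tied entry up raises the maximum while stepping it down leaves the maximum unchanged, so the central difference averages slopes of 1 and 0 for each tied entry, whereas the analytic rule gives the full gradient to one of them. Test inputs with repeated values inside a pooling window can therefore produce large mismatches that have nothing to do with floating-point noise.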

@CarloLucibello
Member

CarloLucibello commented Dec 26, 2020

Maybe @staticfloat or @thebhatman can help with maxpool?

@CarloLucibello
Member

I added the rules for the new softmax interface in #250. I put the rules close to the method definition instead of using a common chainrules.jl file. Are there any general guidelines or common habits about where to place the rules?

@CarloLucibello
Member

CarloLucibello commented Dec 26, 2020

@simeonschaub since this maxpool issue is not strictly related to the porting, we could address it later

@simeonschaub
Member Author

> Are there any general guidelines or common habits about where to place the rules?

Not really AFAIK. Most packages probably added them as an afterthought, so it's usually in a different file, but I see no reason not to put them in the same file if that makes more sense.

> @simeonschaub since this maxpool issue is not strictly related to the porting, we could address it later

Ok, I can try to mark these as broken for now, but it would be good to open a separate issue, since the failures look serious.

@CarloLucibello
Member

> Ok, I can try to mark these as broken for now, but it would be good to open a separate issue, since they look serious.

It's not even worth marking them as broken, since they pass on some platforms and fail on others.

@CarloLucibello
Member

@simeonschaub merge whenever you are ready

@simeonschaub
Member Author

I don't think I have commit rights here. If you merge, it would probably be good to squash merge, since I generally tend to add commits just for better reviewability with the assumption that the PR is going to be squashed anyways. I can also squash myself though, if that makes things easier.

@CarloLucibello
Member

Yes, please squash, then I'll do the merge.

@CarloLucibello
Member

Actually, I didn't know how easy it is to squash-merge. I'll do that.

@CarloLucibello CarloLucibello merged commit 9fc717a into FluxML:master Dec 27, 2020
@simeonschaub simeonschaub deleted the sds/chainrulescore branch December 27, 2020 14:55
4 participants