Use ChainRules RuleConfig #990
Conversation
Still trying to debug why it won't pick up new rrules til a … The following is what happens right now:
The problem is this line
but it should be (shown after doing a …)
Alternative solution for detecting if rules don't exist, since we can't use … First, check if we are hitting the … Otherwise, we have hit a rule and we need no edges (because we will be returning an Expr that calls the rule, so natural edges apply).
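For illustration, here is a minimal sketch of the detection idea described above, outside of Zygote's generated-function machinery: look up which `rrule` method dispatch would select, and treat hitting ChainRulesCore's generic `::Any` fallback as "no rule exists". The helper name `has_rrule` and the example function `myloss` are illustrative assumptions, not Zygote's actual API.

```julia
using ChainRules, ChainRulesCore

# ChainRulesCore's generic fallback `rrule(::Any, ::Vararg) = nothing`:
const RRULE_FALLBACK = which(ChainRulesCore.rrule, Tuple{Any})

# Does a real rrule exist for a call with this argument-type signature?
# If only the fallback matches, the AD must recurse into the function instead
# (and, in Zygote's case, record edges so newly defined rules trigger recompilation).
has_rrule(sig::Type{<:Tuple}) = which(ChainRulesCore.rrule, sig) !== RRULE_FALLBACK

myloss(x) = sum(abs2, x)  # hypothetical user function with no rrule of its own

has_rrule(Tuple{typeof(sin), Float64})            # true: ChainRules defines a rule for sin
has_rrule(Tuple{typeof(myloss), Vector{Float64}}) # false: only the fallback matches
```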
This reverts commit 388bd0f.
The core of this now works. All the tests in … pass.
This is all passing for me locally, so we just need JuliaDiff/ChainRules.jl#441 merged and it will pass on CI. @mzgubic will review tomorrow, with the aim being to merge it tomorrow.
Thanks, looks great, excited about this!
One remaining concern is: how do we opt out of using rules if we can't rely on the return value? I guess one way is to call `rrule_via_ad` on `f` inside the rule that we are opting out of. But that is far from ideal, as it flip-flops between the rule and AD, and this results in redundant differential type changes. Is there a better way?
Correct, that concern remains.
This is not trivial, since (if nothing special is done) it will just infinite-loop.
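For concreteness, here is a minimal sketch of the opt-out workaround being discussed, and of why it loops as written. The function `my_function` is hypothetical; the rule simply hands the call back to the AD via `rrule_via_ad`, but the AD then sees that an rrule exists for `my_function`, calls it, and lands back in `rrule_via_ad`.

```julia
using ChainRulesCore

my_function(x) = sum(abs2, x)  # hypothetical function whose rule we want to opt out of

# "Opt out" by delegating straight back to the AD that called us.
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(my_function), x)
    return ChainRulesCore.rrule_via_ad(config, my_function, x)
end

# With nothing special done, differentiating my_function now recurses forever:
# the AD sees this rrule, calls it, which calls rrule_via_ad, which sees this rrule, ...
# Zygote.gradient(my_function, [1.0, 2.0])  # would infinite-loop
```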
1004: ensure `sum(f,x)` works on GPU r=DhairyaLGandhi a=oxinabox

Seems like there is some pain from #990 re: GPU. In particular, we broke DiffEqSensitivity:
https://buildkite.com/julialang/diffeqsensitivity-dot-jl/builds/169#d254017e-e824-4d9c-854d-f3b348395599/411-877

@ChrisRackauckas's "M"WE is:

```julia
using DiffEqFlux, OrdinaryDiffEq, DiffEqSensitivity
using CUDA, Test, Zygote
CUDA.allowscalar(false)

H = CuArray(rand(Float32, 2, 2))
ann = FastChain(FastDense(1, 4, tanh))
p = initial_params(ann)

function func(x, p, t)
    ann([t], p)[1] * H * x
end

x0 = CuArray(rand(Float32, 2))
x1 = CuArray(rand(Float32, 2))
prob = ODEProblem(func, x0, (0.0f0, 1.0f0))

function evolve(p)
    solve(prob, Tsit5(), p=p, save_start=false, save_everystep=false,
          abstol=1e-4, reltol=1e-4,
          sensealg=QuadratureAdjoint(autojacvec=ZygoteVJP())).u[1]
end

function cost(p)
    x = evolve(p)
    c = sum(abs, x - x1)
    #println(c)
    c
end

grad = Zygote.gradient(cost, p)[1]
@test !iszero(grad[1])
@test iszero(grad[2:4])
@test !iszero(grad[5])
@test iszero(grad[6:end])
```

I am hoping we can get it to fail with just `sum(f, xs)` (which I have added to the tests). I can't run GPU locally, which makes testing this hard. If I have to I will spin up an EC2 instance, but I would really rather not.

From looking at [the logs](https://buildkite.com/julialang/diffeqsensitivity-dot-jl/builds/169#d254017e-e824-4d9c-854d-f3b348395599/411-877), I think what is going on is that the error happens during the forward pass, in particular here:
https://github.com/JuliaDiff/ChainRules.jl/blob/52a0eeadf8d19bff491f224517b7b064ce1ba378/src/rulesets/Base/mapreduce.jl#L46

I think this was why Zygote implemented the pullback of `sum(f, x)` as `sum(f.(x))` (which is slower and more allocate-y than our newer version): so that it could hit the code Zygote has specifically for CUDA that does forwards-mode (which means it doesn't need the Context object containing the IdDict).

So I think the short-term solution is probably to add the old rule for `sum` back in (but for `CuArray` only) here:
https://github.com/FluxML/Zygote.jl/blob/531da8bb7753f46294bc13f9d2a2fdd54917f926/src/lib/broadcast.jl#L244

```julia
# Make sure sum(f, ::CuArray) uses the forward-mode broadcast AD defined above,
# not the ChainRules rrule, which will use the Zygote.Context and thus not be GPU safe.
@adjoint function sum(f, xs::CuArray; kws...)
    @assert !haskey(kws, :init)  # TODO: add init support (julia 1.6)
    return pullback(__context__, (f, xs) -> sum(f.(xs); kws...), f, xs)
end
```

In the longer term, we will probably default to doing the `f` from `sum(f, xs)` in forward mode anyway, so Zygote's rule config can be updated to say that it uses ForwardDiff.jl for its `frule_via_ad`.

Co-authored-by: Lyndon White <[email protected]>
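The last paragraph of the comment above mentions teaching the rule config to use ForwardDiff.jl for its `frule_via_ad`. A hedged, scalar-only sketch of what that could look like, using an illustrative standalone config type rather than Zygote's real one:

```julia
using ChainRulesCore, ForwardDiff

# Illustrative config advertising forward-mode support; not Zygote's actual RuleConfig.
struct ForwardDiffRuleConfig <: RuleConfig{Union{HasForwardsMode, NoReverseMode}} end

# Push a ForwardDiff.Dual through f to get (primal, tangent).
function ChainRulesCore.frule_via_ad(::ForwardDiffRuleConfig, tangents, f, x::Real)
    dx = tangents[2]  # tangent of x (tangents[1] is the tangent of f itself, ignored here)
    d = f(ForwardDiff.Dual(x, dx))
    return ForwardDiff.value(d), ForwardDiff.partials(d, 1)
end

ChainRulesCore.frule_via_ad(ForwardDiffRuleConfig(), (NoTangent(), 1.0), sin, 1.0)
# (0.8414709848078965, 0.5403023058681398) == (sin(1.0), cos(1.0))
```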
This PR's commit history is pretty gross.
It is built on top of the merge of #987 and #967, and then hacks them up.
Right now it's kinda just a sketch.
It might compile and work, but I assume it doesn't.
This needs JuliaDiff/ChainRules.jl#441, which it uses as a test-case.