Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use ChainRules RuleConfig #990

Merged
merged 37 commits into from
Jun 18, 2021
Merged

Use ChainRules RuleConfig #990

merged 37 commits into from
Jun 18, 2021

Conversation

oxinabox
Copy link
Member

@oxinabox oxinabox commented Jun 9, 2021

This PR's commit history is pretty gross.
built on top of the merge of #987 and #967
and then hacks them up.

Right now its kinda just a sketch.
It might compile and work, but I assume it doesn't.

This needs: JuliaDiff/ChainRules.jl#441 which it uses as a test-case.

test/runtests.jl Outdated Show resolved Hide resolved
Project.toml Show resolved Hide resolved
@oxinabox oxinabox mentioned this pull request Jun 14, 2021
@oxinabox
Copy link
Member Author

Still trying to debug why it won't pick up new rrules til a Zygote.refresh() is done.
Generated functions in general don't see things defined after them,
but we should have the right extra back-edged in place to cause them to see them.
Zygote.refresh() fixes it, but we shouldn't need that, we didn't need it for this case before.

The following is what happens right now:
afterr defining the rrule for cr_inner_demo.

julia> Zygote._pullback(Zygote.Context(), cr_outer_demo, 11)[2](1)
T=Tuple{Zygote.ZygoteRuleConfig{Zygote.Context}, typeof(Main.cr_outer_demo), Int64}; return_type=Nothing
instance=rrule(Zygote.ZygoteRuleConfig{Zygote.Context}, typeof(Main.cr_outer_demo), Int64) from rrule(ChainRulesCore.RuleConfig{T} where T, Any, Any...)
T=Tuple{Zygote.ZygoteRuleConfig{Zygote.Context}, typeof(Main.cr_inner_demo), Int64}; return_type=Nothing
instance=rrule(Zygote.ZygoteRuleConfig{Zygote.Context}, typeof(Main.cr_inner_demo), Int64) from rrule(ChainRulesCore.RuleConfig{T} where T, Any, Any...)
T=Tuple{Zygote.ZygoteRuleConfig{Zygote.Context}, typeof(Base.:(*)), Int64, Int64}; return_type=Tuple{Int64, ChainRules.var"#times_pullback#1490"{Int64, Int64}}
T=Tuple{Zygote.ZygoteRuleConfig{Zygote.Context}, typeof(Base.:(*)), Int64, Int64}; return_type=Tuple{Int64, ChainRules.var"#times_pullback#1490"{Int64, Int64}}
(nothing, 50)

julia> Zygote._pullback(Zygote.Context(), cr_outer_demo, 11)[2](1)  # just uses what it found before by ADing fully wihtout rule
(nothing, 50)

The problem is this line

T=Tuple{Zygote.ZygoteRuleConfig{Zygote.Context}, typeof(Main.cr_inner_demo), Int64}; return_type=Nothing

but it should be (shown after doing a Zygote.refresh()

T=Tuple{Zygote.ZygoteRuleConfig{Zygote.Context}, typeof(Main.cr_inner_demo), Int64}; return_type=Tuple{Int64, Main.var"#cr_inner_demo_pullback#10"}

src/compiler/chainrules.jl Outdated Show resolved Hide resolved
@oxinabox
Copy link
Member Author

Alternative solution for detecting if rules don't exist, since we can't ose Core.Compiler.return_type.

First check if we are hitting the rrule(::RuleConfig, args...) = rrule(args...) redispatcher.
then check if the rrule(args...) would hit the rrule(args...) = nothing fallback.
If that is the case then we have no rule and so need to attach the backedge, to the redispatcher (which will be invalidated in turn by backedged from the fallback anyway, so no need to attach backedges there).

Otherwise, we have hit a rule and we need no edges (because we will be returning an Expr that calls the rule, so natural edges apply).

src/compiler/chainrules.jl Outdated Show resolved Hide resolved
src/compiler/chainrules.jl Outdated Show resolved Hide resolved
@oxinabox oxinabox changed the title WIP: use ChainRules RuleConfig Use ChainRules RuleConfig Jun 16, 2021
@oxinabox
Copy link
Member Author

The core of this now works. All the tests in test/chainrules.jl pass.
Mostly what is left is chasing up edge cases in JuliaDiff/ChainRules.jl#441
of which there are surprisingly many (and getting that merged)

Project.toml Outdated Show resolved Hide resolved
@oxinabox
Copy link
Member Author

oxinabox commented Jun 17, 2021

This is all passing for me locally, so just need JuliaDiff/ChainRules.jl#441 merged and it will pass on CI.

@mzgubic will review tomorrow with aim being to merge it tomorrow

Copy link
Collaborator

@mzgubic mzgubic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, looks great, excited about this!

One remaining concern is how do we opt out of using rules if we can't rely on the return value? I guess one way is to call f/rrule_via_ad inside the rule that we are opting out from. But that is far from ideal as it flops between the rule and AD and this results in redundant differential type changes. Is there a better way?

Project.toml Show resolved Hide resolved
src/compiler/chainrules.jl Outdated Show resolved Hide resolved
test/chainrules.jl Outdated Show resolved Hide resolved
src/compiler/interface2.jl Show resolved Hide resolved
test/chainrules.jl Outdated Show resolved Hide resolved
@oxinabox
Copy link
Member Author

One remaining concern is how do we opt out of using rules if we can't rely on the return value?

Correct, that concern remains.
It is out of scope for this PR to solve, i had hoped we could solve it incidentally, but it seems the obvious way doens't work.

I guess one way is to call f/rrule_via_ad inside the rule that we are opting out from.

This is not trivial since, it will (if nothing special is done) just infinite-loop.
Maybe invoke can be used to prevent that?

@oxinabox oxinabox merged commit b250d92 into master Jun 18, 2021
@oxinabox oxinabox deleted the ox/ruleconfig branch June 18, 2021 17:01
bors bot added a commit that referenced this pull request Jun 21, 2021
1004: ensure `sum(f,x)` works on GPU r=DhairyaLGandhi a=oxinabox

Seems like there is some pains from #990 re:GPU.
In particular we broke DiffEqSensitivity
https://buildkite.com/julialang/diffeqsensitivity-dot-jl/builds/169#d254017e-e824-4d9c-854d-f3b348395599/411-877

@ChrisRackauckas 's "M"WE is
```
using DiffEqFlux, OrdinaryDiffEq, DiffEqSensitivity
using CUDA, Test, Zygote
CUDA.allowscalar(false)
H = CuArray(rand(Float32, 2, 2))
ann = FastChain(FastDense(1, 4, tanh))
p = initial_params(ann)
function func(x, p, t)
    ann([t],p)[1]*H*x
end
x0 = CuArray(rand(Float32, 2))
x1 = CuArray(rand(Float32, 2))
prob = ODEProblem(func, x0, (0.0f0, 1.0f0))
function evolve(p)
    solve(prob, Tsit5(), p=p, save_start=false,
          save_everystep=false, abstol=1e-4, reltol=1e-4,
          sensealg=QuadratureAdjoint(autojacvec=ZygoteVJP())).u[1]
end
function cost(p)
    x = evolve(p)
    c = sum(abs,x - x1)
    #println(c)
    c
end
grad = Zygote.gradient(cost,p)[1]
@test !iszero(grad[1])
@test iszero(grad[2:4])
@test !iszero(grad[5])
@test iszero(grad[6:end])
```
I am hoping we can get it to fail with just `sum(f, xs)` (which I have added to tests)}
I can't run GPU locally which makes testing this hard.
If I have to I will spin up an EC2 instance, but I would really rather not.

I think what is going on is, from looking at [the logs](https://buildkite.com/julialang/diffeqsensitivity-dot-jl/builds/169#d254017e-e824-4d9c-854d-f3b348395599/411-877)

The error happens in during the forward pass.
In particular here
https://github.com/JuliaDiff/ChainRules.jl/blob/52a0eeadf8d19bff491f224517b7b064ce1ba378/src/rulesets/Base/mapreduce.jl#L46
I think this was why Zygote implemented
the pullback of sum(f, x) as sum(f.(x)) (which is slower and more allocate-y than our never version)
so that it could hit the code that Zygote has special for CUDA that does forwards-mode.
(Which means it doesn't need the Context object containing the IdDict)
So I think the solution in short-term is probably to add the old rule for sum back in (but for CuArray only) here.
https://github.com/FluxML/Zygote.jl/blob/531da8bb7753f46294bc13f9d2a2fdd54917f926/src/lib/broadcast.jl#L244
```
# Make sure sum(f, ::CuArray) uses forward mode broadcast AD defined above
# Not the ChainRules.rrule which will use the Zygote.Context and thus not be GPU safe
@adjoint function sum(f, xs::CuArray; kws...)
  @Assert !haskey(kws, :init) # TODO add init support (julia 1.6)
  return pullback(__context__, (f, xs) -> sum(f.(xs); kws...), f, xs)
end
```

In the longer-term, we will probably default to doing the f from sum(f, xs)  in forward-mode anyway.
So Zygote's rule config can be updated to say that it does use ForwardDiff.jl for it's frule_via_ad.

Co-authored-by: Lyndon White <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants