Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rule for sum(f, xs) #441

Merged
merged 16 commits into from
Jun 18, 2021
Merged

Rule for sum(f, xs) #441

merged 16 commits into from
Jun 18, 2021

Conversation

oxinabox
Copy link
Member

@oxinabox oxinabox commented Jun 8, 2021

This is the real use of JuliaDiff/ChainRulesCore.jl#363
needs tht to be merged first.
Soon I will make a Zygote PR that will to hit it.

Need to workout testing this

@github-actions github-actions bot added the needs version bump Version needs to be incremented or set to -DEV in Project.toml label Jun 8, 2021
@mzgubic
Copy link
Member

mzgubic commented Jun 9, 2021

Is there an issue with defining frule_via_ad(args...; kwargs...) = frule(args...; kwargs) for testing?

This needs defining the f/rrules for the f being passed to sum, but that's not too bad, right?

@oxinabox oxinabox changed the title WIP: rule for sum(f, xs) Rule for sum(f, xs) Jun 9, 2021
@oxinabox
Copy link
Member Author

oxinabox commented Jun 9, 2021

This works now.
And has tests.
But they don't use ChainRulesTestUtils
We need to address this.

@@ -19,6 +19,27 @@ function rrule(::typeof(sum), x::AbstractArray{T}; dims=:) where {T<:Number}
return y, sum_pullback
end

function rrule(
config::RuleConfig{>:HasReverseMode}, ::typeof(sum), f, xs::AbstractArray{T}; dims=:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
config::RuleConfig{>:HasReverseMode}, ::typeof(sum), f, xs::AbstractArray{T}; dims=:
config::RuleConfig{>:HasReverseMode}, ::typeof(sum), f, xs::Array{T}; dims=:

@willtebbutt's usual request?

Copy link
Member Author

@oxinabox oxinabox Jun 10, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

idk, I also want to define this on iterators more generally even.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also I want to delete the old ones from Zygote, and Zygote is always super-general

mzgubic pushed a commit to JuliaDiff/ChainRulesTestUtils.jl that referenced this pull request Jun 10, 2021
@github-actions github-actions bot removed the needs version bump Version needs to be incremented or set to -DEV in Project.toml label Jun 10, 2021
@oxinabox
Copy link
Member Author

oxinabox commented Jun 10, 2021

This is ready to go, once it's downstream
JuliaDiff/ChainRulesTestUtils.jl#166
and
JuliaDiff/ChainRulesCore.jl#363

are merged and tagged

Copy link
Member

@mzgubic mzgubic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needs compat change

@mzgubic
Copy link
Member

mzgubic commented Jun 11, 2021

This is going to need CRTU 0.7.9, JuliaRegistries/General#38672

Comment on lines 25 to 30
fx_and_pullbacks = map(x->rrule_via_ad(config, f, x), xs)
y = sum(first, fx_and_pullbacks; dims=dims)

pullbacks = last.(fx_and_pullbacks)
function sum_pullback(ȳ)
f̄_and_x̄s = [pullback(ȳ) for pullback in pullbacks]
Copy link
Member

@mcabbott mcabbott Jun 11, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have you investigated the performance of this? Storing an array of the need not be slow, but I wonder how well it works in practice?

I also wonder whether you could avoid making an array of tuples, before sum(first, fx_and_pullbacks) and last.(fx_and_pullbacks) separates them. For complete reductions it might not be hard to just update tot[] += ... within the map(x->rrule_via_ad loop.

Finally, have you given thought to weird arrays, like SMatrix, or CuArrays? For the former at least, perhaps using map instead of a generator for f̄_and_x̄s may better preserve the structure.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have you investigated the performance of this? Storing an array of the need not be slow, but I wonder how well it works in practice?

I have not. This is a first PR to get this feature out the door.
I am only really interested in it working, and having something to test AD's against.

For interest, Zygote's approach is
https://github.com/FluxML/Zygote.jl/blob/12f5c1d75eeaa8c7a818f2db7f8d082956c00cac/src/lib/array.jl#L296
which changes the primal to a sum(f.x) then AD's that expression.
Which hits https://github.com/FluxML/Zygote.jl/blob/12f5c1d75eeaa8c7a818f2db7f8d082956c00cac/src/lib/broadcast.jl#L172-L182
which is strictly worse than this, since it makes a temporary array for all the ys (which we just sum out), and for all the pullbacks

I will time them soon and post back, it will be interesting.

I also wonder whether you could avoid making an array of tuples, before sum(first, fx_and_pullbacks) and last.(fx_and_pullbacks) separates them.

Yeah, I would like to look into that.
It's why i am doing sum rather than for map because I think for map that would be even more important.
It's a fiddly little pattern to write to make sure everything gets the right types, so I didn't want to do it for the first PR.

Finally, have you given thought to weird arrays, like SMatrix, or CuArrays? For the former at least, perhaps using map instead of a generator for f̄_and_x̄s may better preserve the structure.

Good point, that's worth testing.

Copy link
Member

@mcabbott mcabbott Jun 11, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without AD, the reason to write sum(f, x) instead of sum(f.(x)) is precisely to save allocations. It would be nice if the AD could preserve that, although possibly tricky.

I guess the size of the array of closures will depend on how much they capture, which may include both x and y = f.(x). Ideally it would never include x since that already has its own array... and sometimes it would be quicker to re-calculate y but that seems even harder to arrange. I suppose you could make it a user option, by declaring that sum(f, x) is always going to call f twice, once forwards, once back --- for use with low-cost f where the allocations matter. With high-cost f, there is little lost by calling sum(f.(x)).

fiddly little pattern to write to make sure everything gets the right types

Re where to write the sum on the forward pass, one possibility might be to hook into one of the later functions, maybe mapreduce!, when Julia has already made the array it's going to sum into.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did a little experiment into using mapreduce((net, pullbacks), (val, pb)) = (net + val, push!(pullbacks, pb))
and I didn't get any performance improvement. Infact I got a large regression, but that could have been me messing stuff up.
I would like to do more testing of this.

@oxinabox
Copy link
Member Author

Here are benchmarks
Script:

using Zygote
using BenchmarkTools

const xs = randn(10_000)
@btime Zygote.pullback(sum, abs, $xs)[2](1);

machine details

julia> versioninfo()
Julia Version 1.6.2-pre.0
Commit dd122918ce* (2021-04-23 21:21 UTC)
Platform Info:
  OS: macOS (x86_64-apple-darwin20.3.0)
  CPU: Intel(R) Core(TM) i7-8559U CPU @ 2.70GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-11.0.1 (ORCJIT, skylake)

Zygote v0.6.12

  515.631 μs (30097 allocations: 1.07 MiB)

Zygote#ox/ruleconfig

using this PR as of 4207633

  55.476 μs (17 allocations: 703.61 KiB)

So while I don't think this code is the fastest code that can do this.
I don't this we need to worry about making things worst.
Zygote is just surfiently terrible right now

@mcabbott
Copy link
Member

FWIW, some other ways (although on a different machine!)

julia> @btime Zygote.pullback(sum, abs, $xs)[2](1);
  479.167 μs (30098 allocations: 1.07 MiB)

julia> using Tullio, ForwardDiff

julia> @btime Zygote.pullback(xs -> (@tullio _ := abs(xs[i])  avx=false threads=false), $xs)[2](1);
  27.375 μs (2 allocations: 78.20 KiB)

julia> @btime Zygote.pullback(xs -> (@tullio _ := abs(xs[i])  avx=false threads=false grad=Dual), $xs)[2](1);
  6.492 μs (2 allocations: 78.20 KiB)  # evaluates f twice, the second time with Dual(x,1) for gradient

julia> @btime copy($xs); # just to see alloc
  2.708 μs (2 allocations: 78.20 KiB)

@oxinabox
Copy link
Member Author

Yeah, I agree faster things are possible.
Can be follow up PRs, as long as we don't make Zygote worse, I am happy with this.

Cool (and I guess not surprising) that reverse via forward is fastest here.
Which is something we now support writing using RuleConfigs.

Comment on lines +80 to +112
# Fix dispatch for this pidgeon-hole optimization,
# Rules with RuleConfig dispatch with priority over without (regardless of other args).
# and if we don't specify what do do for one that HasReverseMode then it is ambigious
for Config in (RuleConfig, RuleConfig{>:HasReverseMode})
@eval function rrule(
::$Config, ::typeof(sum), ::typeof(abs2), x::AbstractArray{T}; dims=:,
) where {T<:Union{Real,Complex}}
return rrule(sum, abs2, x; dims=dims)
end
end
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't love this.
But I think it is as good as it gets.
We could require every rule is written to have to have a RuleConfig,
but we would still get the ambiguity, and so would still need the same code (just with one less thing in the top loop.

If we are happy with this I will make a follow up PR to the ChainRulesCore docs, warning about this.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Eww. But I can't think of anything nicer either. Can we add tests for this as well?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why would rrule(config::RuleConfig{&gt;:HasReverseMode}, ::typeof(sum), f::typeof(abs2), xs::AbstractArray) not be more specific?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is, which is why we have to generate that.

The problem is with:

rrule(config::RuleConfig, ::typeof(sum), f::typeof(abs2), xs::AbstractArray)

which is what we could make the required format.
That one is ambiguous with

rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(sum), f, xs::AbstractArray)

@oxinabox oxinabox requested a review from mzgubic June 11, 2021 16:53
@codecov-commenter
Copy link

codecov-commenter commented Jun 11, 2021

Codecov Report

Merging #441 (0d1d140) into master (5c46ba5) will increase coverage by 0.01%.
The diff coverage is 100.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master     #441      +/-   ##
==========================================
+ Coverage   98.39%   98.40%   +0.01%     
==========================================
  Files          21       21              
  Lines        1989     2002      +13     
==========================================
+ Hits         1957     1970      +13     
  Misses         32       32              
Impacted Files Coverage Δ
src/rulesets/Base/mapreduce.jl 98.90% <100.00%> (+0.18%) ⬆️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 5c46ba5...0d1d140. Read the comment docs.

@github-actions github-actions bot added the needs version bump Version needs to be incremented or set to -DEV in Project.toml label Jun 17, 2021
@github-actions github-actions bot removed the needs version bump Version needs to be incremented or set to -DEV in Project.toml label Jun 17, 2021
@oxinabox oxinabox requested a review from mzgubic June 17, 2021 19:21
Copy link
Member

@mzgubic mzgubic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just a couple of minor things, lgtm otherwise, feel free to merge when you want

@@ -7,3 +7,4 @@ docs/build
docs/site
.idea/*
dev/*
.vscode/settings.json
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this intended?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it's not related to this PR though.

It's like how we have the idea editor config ignored

Comment on lines +55 to +59
f̄ = if fieldcount(typeof(f)) === 0 # Then don't need to worry about derivative wrt f
NoTangent()
else
sum(first, f̄_and_x̄s)
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
= if fieldcount(typeof(f)) === 0 # Then don't need to worry about derivative wrt f
NoTangent()
else
sum(first, f̄_and_x̄s)
end
= sum(first, f̄_and_x̄s)

This should also work, right?
It looks a lot cleaner to me.
I get that we can avoid summing a vector of NoTangent()s, but we have already allocated the vector so shouldn't be too much slower right?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I thought that so I timed it.
It's 20% slower.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh wow, I did not expect that much slower! What if we define sum(::Array{AbstractZero}) = ZeroTangent() or something similar?

Copy link
Member Author

@oxinabox oxinabox Jun 18, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we would need more like

sum(::typeof(first), ::Array{Tuple{T, Any}) where T isa AbstractZero = T()

which seems more involved than i want in my life.
Though it would address @simeonschaub 's concerns here #441 (comment)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah yeah, fair enough

call(f, x) = f(x) # we need to broadcast this to handle dims kwarg
f̄_and_x̄s = call.(pullbacks, ȳ)
# no point thunking as most of work is in f̄_and_x̄s which we need to compute for both
f̄ = if fieldcount(typeof(f)) === 0 # Then don't need to worry about derivative wrt f
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not necessarily related to this PR, but it would be good to define a function for fieldcount(typeof(f)) === 0, e.g. hasstructure(f) or iscomposite(f)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure that is necessarily the right check though. There could be number-like or array-like functors which do have a well-defined derivative wrt f, but don't have any fields. Probably not a big deal for now, but that might be something to keep in mind for general design decisions.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

true, they are very unusual though. Probably a good thing to worry about when defining @CarloLucibello 's iscomposite/hasstructure or what ever we call it, maybe hastangent?

Cases I can think of are functors that are also:

  • number types defined using primative,
  • Something that is like FillArrays.One() or OneHot that pushes size and index into the type-level.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tkf's suggestion was to use Base.issingletontype, which I think would avoid your primitive example. Not sure there's a One, but other cases with a value in the type are regarded as constants:

julia> FillArrays.Ones(3) |> typeof |> fieldnames
(:axes,)

julia> FillArrays.One
ERROR: UndefVarError: One not defined

julia> get2(::Val{x}) where x = x^2;

julia> gradient(get2∘Val, 3.14)
(nothing,)

julia> gradient(get2, Val(3.14))
(nothing,)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Xref weird example (from the tests) here: FluxML/Zygote.jl#1001 (comment)

tl;dr is that global s += 1 is not detected by these tests. That's a test of order of iteration.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants