Rule for sum(f, xs) #441
Is there an issue with defining This needs defining the
This works now.
src/rulesets/Base/mapreduce.jl
Outdated
@@ -19,6 +19,27 @@ function rrule(::typeof(sum), x::AbstractArray{T}; dims=:) where {T<:Number}
    return y, sum_pullback
end

function rrule(
    config::RuleConfig{>:HasReverseMode}, ::typeof(sum), f, xs::AbstractArray{T}; dims=:
-    config::RuleConfig{>:HasReverseMode}, ::typeof(sum), f, xs::AbstractArray{T}; dims=:
+    config::RuleConfig{>:HasReverseMode}, ::typeof(sum), f, xs::Array{T}; dims=:
@willtebbutt's usual request?
idk, I also want to define this on iterators more generally even.
Also I want to delete the old ones from Zygote, and Zygote is always super-general
This is ready to go, once its downstream changes are merged and tagged.
Needs compat change
This is going to need ChainRulesTestUtils (CRTU) 0.7.9, JuliaRegistries/General#38672
src/rulesets/Base/mapreduce.jl
Outdated
    fx_and_pullbacks = map(x->rrule_via_ad(config, f, x), xs)
    y = sum(first, fx_and_pullbacks; dims=dims)

    pullbacks = last.(fx_and_pullbacks)
    function sum_pullback(ȳ)
        f̄_and_x̄s = [pullback(ȳ) for pullback in pullbacks]
Have you investigated the performance of this? Storing an array of the need not be slow, but I wonder how well it works in practice?
I also wonder whether you could avoid making an array of tuples, before sum(first, fx_and_pullbacks) and last.(fx_and_pullbacks) separates them. For complete reductions it might not be hard to just update tot[] += ... within the map(x->rrule_via_ad loop.
Finally, have you given thought to weird arrays, like SMatrix, or CuArrays? For the former at least, perhaps using map instead of a generator for f̄_and_x̄s may better preserve the structure.
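The two forward-pass shapes discussed in this comment might look roughly like the sketch below. It uses only Base: value_and_pullback is a hypothetical stand-in for rrule_via_ad(config, f, x), handles only abs with a hand-written derivative, and ignores the tangent for f and the dims keyword. It is meant to illustrate the bookkeeping, not the rule in this PR.

```julia
# Hypothetical stand-in for rrule_via_ad(config, f, x): returns f(x) and a pullback.
# Only abs is covered, with its derivative written by hand (an assumption of this sketch).
value_and_pullback(::typeof(abs), x::Real) = (abs(x), ȳ -> ȳ * sign(x))

# Shape used in the diff: build an array of (value, pullback) tuples, then separate them.
function sum_with_pullback(f, xs::AbstractArray)
    fx_and_pullbacks = map(x -> value_and_pullback(f, x), xs)
    y = sum(first, fx_and_pullbacks)
    pullbacks = map(last, fx_and_pullbacks)
    # map (rather than a generator) so the result follows the container type of pullbacks
    sum_pullback(ȳ) = map(pb -> pb(ȳ), pullbacks)
    return y, sum_pullback
end

# Shape suggested in review: accumulate the total inside the map,
# so that only the pullbacks are stored (complete reductions only).
function sum_with_pullback_acc(f, xs::AbstractArray)
    # Guessing the accumulator type up front is the "fiddly" part; float eltype is assumed here.
    tot = Ref(zero(float(eltype(xs))))
    pullbacks = map(xs) do x
        fx, pb = value_and_pullback(f, x)
        tot[] += fx
        pb
    end
    sum_pullback(ȳ) = map(pb -> pb(ȳ), pullbacks)
    return tot[], sum_pullback
end

xs = [1.0, -2.0, 3.0]
y1, back1 = sum_with_pullback(abs, xs)
y2, back2 = sum_with_pullback_acc(abs, xs)
```

Both variants should return the same value and the same x̄s; whether the second is actually faster would need to be measured.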
Have you investigated the performance of this? Storing an array of the need not be slow, but I wonder how well it works in practice?
I have not. This is a first PR to get this feature out the door.
I am only really interested in it working, and having something to test ADs against.
For interest, Zygote's approach is
https://github.com/FluxML/Zygote.jl/blob/12f5c1d75eeaa8c7a818f2db7f8d082956c00cac/src/lib/array.jl#L296
which changes the primal to a sum(f.(x)) and then ADs that expression.
Which hits https://github.com/FluxML/Zygote.jl/blob/12f5c1d75eeaa8c7a818f2db7f8d082956c00cac/src/lib/broadcast.jl#L172-L182
which is strictly worse than this, since it makes a temporary array for all the ys (which we just sum out), and for all the pullbacks.
I will time them soon and post back, it will be interesting.
I also wonder whether you could avoid making an array of tuples, before sum(first, fx_and_pullbacks) and last.(fx_and_pullbacks) separates them.
Yeah, I would like to look into that.
It's why I am doing sum rather than map, because I think for map that would be even more important.
It's a fiddly little pattern to write to make sure everything gets the right types, so I didn't want to do it for the first PR.
Finally, have you given thought to weird arrays, like SMatrix, or CuArrays? For the former at least, perhaps using map instead of a generator for f̄_and_x̄s may better preserve the structure.
Good point, that's worth testing.
Without AD, the reason to write sum(f, x) instead of sum(f.(x)) is precisely to save allocations. It would be nice if the AD could preserve that, although possibly tricky.
I guess the size of the array of closures will depend on how much they capture, which may include both x and y = f.(x). Ideally it would never include x since that already has its own array... and sometimes it would be quicker to re-calculate y but that seems even harder to arrange. I suppose you could make it a user option, by declaring that sum(f, x) is always going to call f twice, once forwards, once back, for use with low-cost f where the allocations matter. With high-cost f, there is little lost by calling sum(f.(x)).
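A rough sketch of the "call f twice" option described above, again using only Base. value_and_pullback is a hypothetical stand-in for rrule_via_ad and covers only abs; the tangent for f and the dims keyword are left out.

```julia
# Hypothetical stand-in for rrule_via_ad(config, f, x); only abs is covered.
value_and_pullback(::typeof(abs), x::Real) = (abs(x), ȳ -> ȳ * sign(x))

function sum_recompute(f, xs::AbstractArray)
    # Forward pass: plain sum(f, xs), no array of closures is stored.
    y = sum(f, xs)
    # Backward pass: f is differentiated again for each element, so the only
    # thing captured is a reference to xs, which already has its own array.
    sum_pullback(ȳ) = map(x -> last(value_and_pullback(f, x))(ȳ), xs)
    return y, sum_pullback
end

xs = [1.0, -2.0, 3.0]
y, back = sum_recompute(abs, xs)
x̄s = back(1.0)
```

The trade-off is the one named in the comment: f runs once forwards and once backwards, which suits cheap f where allocations dominate.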
fiddly little pattern to write to make sure everything gets the right types
Re where to write the sum on the forward pass, one possibility might be to hook into one of the later functions, maybe mapreduce!, when Julia has already made the array it's going to sum into.
I did a little experiment into using mapreduce((net, pullbacks), (val, pb)) = (net + val, push!(pullbacks, pb))
and I didn't get any performance improvement. In fact I got a large regression, but that could have been me messing stuff up.
I would like to do more testing of this.
Here are benchmarks using Zygote:
using BenchmarkTools, Zygote
const xs = randn(10_000)
@btime Zygote.pullback(sum, abs, $xs)[2](1);
Machine details:
julia> versioninfo()
Julia Version 1.6.2-pre.0
Commit dd122918ce* (2021-04-23 21:21 UTC)
Platform Info:
  OS: macOS (x86_64-apple-darwin20.3.0)
  CPU: Intel(R) Core(TM) i7-8559U CPU @ 2.70GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-11.0.1 (ORCJIT, skylake)
Zygote v0.6.12
Zygote#ox/ruleconfig using this PR as of 4207633
So while I don't think this code is the fastest code that can do this.
FWIW, some other ways (although on a different machine!)
Yeah, I agree faster things are possible. Cool (and I guess not surprising) that reverse via forward is fastest here.
# Fix dispatch for this pigeon-hole optimization,
# Rules with RuleConfig dispatch with priority over without (regardless of other args).
# and if we don't specify what to do for one that HasReverseMode then it is ambiguous
for Config in (RuleConfig, RuleConfig{>:HasReverseMode})
    @eval function rrule(
        ::$Config, ::typeof(sum), ::typeof(abs2), x::AbstractArray{T}; dims=:,
    ) where {T<:Union{Real,Complex}}
        return rrule(sum, abs2, x; dims=dims)
    end
end
I don't love this.
But I think it is as good as it gets.
We could require that every rule be written to take a RuleConfig, but we would still get the ambiguity, and so would still need the same code (just with one less thing in the top loop).
If we are happy with this I will make a follow up PR to the ChainRulesCore docs, warning about this.
Eww. But I can't think of anything nicer either. Can we add tests for this as well?
Why would rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(sum), f::typeof(abs2), xs::AbstractArray) not be more specific?
It is, which is why we have to generate that.
The problem is with:
rrule(config::RuleConfig, ::typeof(sum), f::typeof(abs2), xs::AbstractArray)
which is what we could make the required format.
That one is ambiguous with
rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(sum), f, xs::AbstractArray)
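The ambiguity described here can be reproduced in miniature with plain Julia dispatch. In the sketch below, AnyConfig, HasReverse and rule are hypothetical stand-ins for RuleConfig, RuleConfig{>:HasReverseMode} and rrule; the real types use type parameters rather than subtyping, so this only mirrors the shape of the problem.

```julia
# Hypothetical stand-ins; not the ChainRulesCore types.
abstract type AnyConfig end
struct HasReverse <: AnyConfig end

# Config-free specialised rule (the abs2 fast path).
rule(::typeof(sum), ::typeof(abs2), xs::AbstractArray) = :abs2_fast_path

# Generic rule: specific in the config, generic in f.
rule(::HasReverse, ::typeof(sum), f, xs::AbstractArray) = :generic_via_ad

# Forwarding method: generic in the config, specific in f.
rule(::AnyConfig, ::typeof(sum), ::typeof(abs2), xs::AbstractArray) = rule(sum, abs2, xs)

# Neither of the two config-taking methods is more specific for this call,
# so Julia reports a method ambiguity as a MethodError.
was_ambiguous = try
    rule(HasReverse(), sum, abs2, [1.0, 2.0])
    false
catch err
    err isa MethodError
end

# A method that is specific in both arguments breaks the tie.
rule(::HasReverse, ::typeof(sum), ::typeof(abs2), xs::AbstractArray) = rule(sum, abs2, xs)
resolved = rule(HasReverse(), sum, abs2, [1.0, 2.0])
```

This is why the loop in the diff generates a method for each config type: the tie-breaking method has to exist whichever format is made the required one.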
Codecov Report
@@ Coverage Diff @@
## master #441 +/- ##
==========================================
+ Coverage 98.39% 98.40% +0.01%
==========================================
Files 21 21
Lines 1989 2002 +13
==========================================
+ Hits 1957 1970 +13
Misses 32 32
just a couple of minor things, lgtm otherwise, feel free to merge when you want
@@ -7,3 +7,4 @@ docs/build
docs/site
.idea/*
dev/*
.vscode/settings.json
is this intended?
Yes, it's not related to this PR though.
It's like how we have the .idea editor config ignored.
    f̄ = if fieldcount(typeof(f)) === 0 # Then don't need to worry about derivative wrt f
        NoTangent()
    else
        sum(first, f̄_and_x̄s)
    end
-    f̄ = if fieldcount(typeof(f)) === 0 # Then don't need to worry about derivative wrt f
-        NoTangent()
-    else
-        sum(first, f̄_and_x̄s)
-    end
+    f̄ = sum(first, f̄_and_x̄s)
This should also work, right?
It looks a lot cleaner to me.
I get that we can avoid summing a vector of NoTangent()s, but we have already allocated the vector so shouldn't be too much slower right?
Yeah I thought that so I timed it.
It's 20% slower.
oh wow, I did not expect that much slower! What if we define sum(::Array{<:AbstractZero}) = ZeroTangent() or something similar?
we would need more like
sum(::typeof(first), ::Array{<:Tuple{T,Any}}) where {T<:AbstractZero} = T()
which seems more involved than I want in my life.
Though it would address @simeonschaub's concerns here #441 (comment)
ah yeah, fair enough
    call(f, x) = f(x)  # we need to broadcast this to handle dims kwarg
    f̄_and_x̄s = call.(pullbacks, ȳ)
    # no point thunking as most of work is in f̄_and_x̄s which we need to compute for both
    f̄ = if fieldcount(typeof(f)) === 0 # Then don't need to worry about derivative wrt f
Not necessarily related to this PR, but it would be good to define a function for fieldcount(typeof(f)) === 0, e.g. hasstructure(f) or iscomposite(f)
I am not sure that is necessarily the right check though. There could be number-like or array-like functors which do have a well-defined derivative wrt f, but don't have any fields. Probably not a big deal for now, but that might be something to keep in mind for general design decisions.
true, they are very unusual though. Probably a good thing to worry about when defining @CarloLucibello's iscomposite/hasstructure or whatever we call it, maybe hastangent?
Cases I can think of are functors that are also:
- number types defined using primitive,
- something that is like FillArrays.One() or OneHot that pushes size and index into the type-level.
@tkf's suggestion was to use Base.issingletontype, which I think would avoid your primitive example. Not sure there's a One, but other cases with a value in the type are regarded as constants:
julia> FillArrays.Ones(3) |> typeof |> fieldnames
(:axes,)
julia> FillArrays.One
ERROR: UndefVarError: One not defined
julia> get2(::Val{x}) where x = x^2;
julia> gradient(get2∘Val, 3.14)
(nothing,)
julia> gradient(get2, Val(3.14))
(nothing,)
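To make the difference between the two checks concrete, here is a small Base-only comparison. has_no_fields and is_singleton are hypothetical helper names chosen for this sketch (not the proposed hasstructure/iscomposite), and Scale and make_adder are illustrative functors, not anything from the PR.

```julia
# A functor with a field, and a closure that captures a value.
struct Scale
    a::Float64
end
(s::Scale)(x) = s.a * x

make_adder(b) = x -> x + b
adder = make_adder(2.0)

# Hypothetical helper names for the two checks discussed in the thread.
has_no_fields(f) = fieldcount(typeof(f)) === 0      # the check used in the diff
is_singleton(f) = Base.issingletontype(typeof(f))   # the suggested alternative

# Both checks agree for a plain function, a functor with a field, and a closure.
plain   = (has_no_fields(abs), is_singleton(abs))
functor = (has_no_fields(Scale(2.0)), is_singleton(Scale(2.0)))
closure = (has_no_fields(adder), is_singleton(adder))

# They differ for a primitive number type: no fields, but not a singleton.
number  = (has_no_fields(1.0), is_singleton(1.0))

# A value pushed into the type (as with Val) is a singleton, i.e. treated as a constant.
in_type = (has_no_fields(Val(3.14)), is_singleton(Val(3.14)))
```

Under this sketch the primitive-type case is the only one where the two checks disagree, which matches the point that Base.issingletontype would avoid the primitive example.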
Xref weird example (from the tests) here: FluxML/Zygote.jl#1001 (comment)
tl;dr is that global s += 1 is not detected by these tests. That's a test of order of iteration.
This is the real use of JuliaDiff/ChainRulesCore.jl#363
and needs that to be merged first.
Soon I will make a Zygote PR that will hit it.
Need to work out testing this.