PairwiseFusion layer #1971
Conversation
src/layers/basic.jl
Outdated
function (m::PairwiseFusion)(x::T) where {T}
  nlayers = length(m.layers)
  getinput(i) = T <: Tuple ? x[i] : x
  if T <: Tuple
    nx = length(x)
    if nx != nlayers
      throw(ArgumentError("PairwiseFusion with $nlayers layers takes $nlayers inputs, but got $nx inputs"))
    end
  end
  outputs = [m.layers[1](getinput(1))]
  for i in 2:nlayers
    push!(outputs, m.layers[i](m.connection(getinput(i), outputs[i - 1])))
  end
  return outputs
end
I fear this will be very AD-unfriendly. Lux uses a generated function.
Hmm, this does seem to be throwing up an error when I try a gradient with a very simple example:
julia> x = (rand(1, 10), rand(30, 10));
julia> model = PairwiseFusion(+, Dense(1, 30), Dense(30, 10));
julia> @benchmark Zygote.gradient(p -> sum($model(p)), $x)
ERROR: DimensionMismatch: dimensions must match: a has dims (Base.OneTo(30), Base.OneTo(10)), b has dims (Base.OneTo(10), Base.OneTo(10)), mismatch at 1
Stacktrace:
[1] promote_shape
@ ./indices.jl:178 [inlined]
[2] promote_shape(a::Matrix{Float64}, b::Matrix{Float64})
@ Base ./indices.jl:169
[3] +(A::Matrix{Float64}, Bs::Matrix{Float64})
@ Base ./arraymath.jl:14
[4] adjoint
@ ~/.julia/packages/Zygote/DkIUK/src/lib/array.jl:769 [inlined]
[5] _pullback
@ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
[6] _pullback
@ ~/Code/Flux.jl/src/layers/basic.jl:602 [inlined]
[7] _pullback(ctx::Zygote.Context, f::PairwiseFusion{typeof(+), Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, args::Tuple{Matrix{Float64}, Matrix{Float64}})
@ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
[8] _pullback
@ ./REPL[8]:1 [inlined]
[9] _pullback(ctx::Zygote.Context, f::var"#7#8"{PairwiseFusion{typeof(+), Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}, args::Tuple{Matrix{Float64}, Matrix{Float64}})
@ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
[10] _pullback(f::Function, args::Tuple{Matrix{Float64}, Matrix{Float64}})
@ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:34
[11] pullback(f::Function, args::Tuple{Matrix{Float64}, Matrix{Float64}})
@ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:40
[12] gradient(f::Function, args::Tuple{Matrix{Float64}, Matrix{Float64}})
@ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:75
[13] var"##core#346"(model#341::PairwiseFusion{typeof(+), Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, x#342::Tuple{Matrix{Float64}, Matrix{Float64}})
@ Main ~/.julia/packages/BenchmarkTools/7xSXH/src/execution.jl:489
[14] var"##sample#347"(::Tuple{PairwiseFusion{typeof(+), Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Tuple{Matrix{Float64}, Matrix{Float64}}}, __params::BenchmarkTools.Parameters)
@ Main ~/.julia/packages/BenchmarkTools/7xSXH/src/execution.jl:495
[15] _run(b::BenchmarkTools.Benchmark, p::BenchmarkTools.Parameters; verbose::Bool, pad::String, kwargs::Base.Pairs{Symbol, Integer, NTuple{4, Symbol}, NamedTuple{(:samples, :evals, :gctrial, :gcsample), Tuple{Int64, Int64, Bool, Bool}}})
@ BenchmarkTools ~/.julia/packages/BenchmarkTools/7xSXH/src/execution.jl:99
[16] #invokelatest#2
@ ./essentials.jl:801 [inlined]
[17] #run_result#45
@ ~/.julia/packages/BenchmarkTools/7xSXH/src/execution.jl:34 [inlined]
[18] run(b::BenchmarkTools.Benchmark, p::BenchmarkTools.Parameters; progressid::Nothing, nleaves::Float64, ndone::Float64, kwargs::Base.Pairs{Symbol, Integer, NTuple{5, Symbol}, NamedTuple{(:verbose, :samples, :evals, :gctrial, :gcsample), Tuple{Bool, Int64, Int64, Bool, Bool}}})
@ BenchmarkTools ~/.julia/packages/BenchmarkTools/7xSXH/src/execution.jl:117
[19] #warmup#54
@ ~/.julia/packages/BenchmarkTools/7xSXH/src/execution.jl:169 [inlined]
[20] warmup(item::BenchmarkTools.Benchmark)
@ BenchmarkTools ~/.julia/packages/BenchmarkTools/7xSXH/src/execution.jl:168
[21] top-level scope
@ ~/.julia/packages/BenchmarkTools/7xSXH/src/execution.jl:393
[22] top-level scope
@ ~/.julia/packages/CUDA/qAl31/src/initialization.jl:52
The forward pass works just fine, so I'm not sure what's going wrong. I'll try with a generated function and see if that makes a difference.
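For context, the shapes in that successful forward pass can be read off the layer sizes above; a quick sketch (derived from the example, not pasted REPL output):

```julia
# Shapes derived from the example above (Dense(1, 30), Dense(30, 10), batch of 10);
# the push!-based forward returns one output per layer.
y = model(x)
@assert size(y[1]) == (30, 10)   # Dense(1, 30) applied to x[1]
@assert size(y[2]) == (10, 10)   # Dense(30, 10) applied to connection(x[2], y[1])
```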
I would recommend designing specifically for the case where x is a tuple, then adding a method that turns non-tuples into tuples for the other cases.
So I've used a generated function. It's pretty much the same as the one in Lux, but that one stops at a slightly different stage (after combining the output with the next input, meaning N layers actually end up requiring N + 1 inputs, unlike what's advertised). The gradient works now, but it's consistently about 6 µs off the pace of Lux when I benchmark it; I'm not really sure why.
Also, let's try to avoid generated functions if possible; we have such an aesthetically pleasant codebase...
The problem is array mutation. Maybe use a comprehension instead?
And as @darsnack suggested
(m::PairwiseFusion)(x) = m(ntuple(i -> x, length(m.layers)))
function (m::PairwiseFusion)(x::Tuple)
...
Okay, so I tried that, and it benchmarks more than 2x slower:
Without generated function
function (m::PairwiseFusion)(x::Tuple)
  nlayers = length(m.layers)
  nx = length(x)
  if nx != nlayers
    throw(ArgumentError("PairwiseFusion with $nlayers layers takes $nlayers inputs, but got $nx inputs"))
  end
  out(::Val{1}) = m.layers[1](x[1])
  out(::Val{i}) where i = m.layers[i](m.connection(out(Val(i - 1)), x[i]))
  outputs = [out(Val(i)) for i in 1:nlayers]
  return outputs
end
Benchmark
julia> model = PairwiseFusion(+, Dense(1, 30), Dense(30, 10));
julia> x = (rand(1, 10), rand(30, 10));
julia> @benchmark Zygote.gradient(p -> sum($model(p)[1]), $x)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 50.833 μs … 7.276 ms ┊ GC (min … max): 0.00% … 98.10%
Time (median): 55.209 μs ┊ GC (median): 0.00%
Time (mean ± σ): 60.031 μs ± 174.896 μs ┊ GC (mean ± σ): 7.05% ± 2.40%
▄█▇▇▇▃▁
▁▁▃▄▅▇▆▇▇█████████▆▅▅▄▄▃▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▃
50.8 μs Histogram: frequency by time 70.8 μs <
Memory estimate: 53.98 KiB, allocs estimate: 258.
Not storing intermediate outputs (i.e. just calling out(nlayers) directly instead of creating the outputs array) brings this down marginally, but not enough:
julia> @benchmark Zygote.gradient(p -> sum($model(p)[1]), $x)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 45.625 μs … 7.820 ms ┊ GC (min … max): 0.00% … 97.97%
Time (median): 48.916 μs ┊ GC (median): 0.00%
Time (mean ± σ): 52.350 μs ± 154.053 μs ┊ GC (mean ± σ): 5.80% ± 1.96%
▁▂▅▇▅█▆██▄▄▁▁
▁▁▁▂▂▂▃▄▆▆█████████████▇▆▅▄▄▃▃▃▂▂▂▂▂▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▃
45.6 μs Histogram: frequency by time 57.1 μs <
Memory estimate: 35.73 KiB, allocs estimate: 269.
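For reference, the "no intermediate outputs" variant described above would look roughly like this (a sketch; the exact code was not posted in the thread):

```julia
# Sketch of the variant benchmarked above: same Val-based recursion, but only the
# final value is returned instead of collecting every stage into an array.
function (m::PairwiseFusion)(x::Tuple)
  nlayers = length(m.layers)
  out(::Val{1}) = m.layers[1](x[1])
  out(::Val{i}) where i = m.layers[i](m.connection(out(Val(i - 1)), x[i]))
  return out(Val(nlayers))
end
```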
With generated function
function (m::PairwiseFusion)(x::T) where {T}
  lx = length(x)
  N = length(m.layers)
  if T <: Tuple && lx != N
    throw(ArgumentError("PairwiseFusion with $N sub-layers can take one input or $N inputs, but got $lx inputs"))
  end
  applypairwisefusion(m.layers, m.connection, x)
end

@generated function applypairwisefusion(layers::NamedTuple{names}, connection, x::T) where {names, T}
  N = length(names)
  y_symbols = [gensym() for _ in 1:(N + 1)]
  getinput(i) = T <: Tuple ? :(x[$i]) : :x
  calls = [:($(y_symbols[N + 1]) = $(getinput(1)))]
  append!(calls,
          [:($(y_symbols[i]) = layers[$i]($(y_symbols[N + 1]));
             $(y_symbols[N + 1]) = connection($(y_symbols[i]), $(getinput(i + 1))))
           for i in 1:N - 1])
  push!(calls, :($(y_symbols[N]) = layers[$N]($(y_symbols[N + 1]))))
  push!(calls, :(return $(y_symbols[N])))
  return Expr(:block, calls...)
end
Benchmark:
julia> model = PairwiseFusion(+, Dense(1, 30), Dense(30, 10));
julia> x = (rand(1, 10), rand(30, 10));
julia> @benchmark Zygote.gradient(p -> sum($model(p)[1]), $x)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 22.042 μs … 4.498 ms ┊ GC (min … max): 0.00% … 97.52%
Time (median): 24.125 μs ┊ GC (median): 0.00%
Time (mean ± σ): 27.066 μs ± 107.292 μs ┊ GC (mean ± σ): 9.55% ± 2.40%
▃▅▅▆▆▆▇▄█▆▅▃▂
▁▂▃▅███████████████▇▆▅▄▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▃
22 μs Histogram: frequency by time 31.8 μs <
Memory estimate: 28.23 KiB, allocs estimate: 138.
out(::Val{1}) = m.layers[1](x[1])
out(::Val{i}) where i = m.layers[i](m.connection(out(Val(i - 1)), x[i]))
outputs = [out(Val(i)) for i in 1:nlayers]
This doesn't reuse previous computation, so it scales quadratically with nlayers. You want something like this (I didn't run the code, so it may be buggy):
h = x[1]
out(::Val{1}) = m.layers[1](h)
out(::Val{i}) where i = m.layers[i](m.connection(h, x[i]))
outputs = [(h = out(Val(i)); h) for i in 1:nlayers]
or just
h = x[1]
out(i) = i === 1 ? m.layers[1](h) : m.layers[i](m.connection(h, x[i]))
outputs = [(h = out(i); h) for i in 1:nlayers]
if performance is the same
I actually tried that later (I forgot to add the benchmarks for it). It does better than the purely recursive version but it's still way off the pace of the generated function:
julia> @benchmark Zygote.gradient(p -> sum($model(p)[1]), $x)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 41.458 μs … 6.534 ms ┊ GC (min … max): 0.00% … 96.79%
Time (median): 44.791 μs ┊ GC (median): 0.00%
Time (mean ± σ): 49.408 μs ± 140.782 μs ┊ GC (mean ± σ): 6.80% ± 2.37%
▃▅█▃
▁▂▂▄▅▇████▆▄▃▃▃▃▃▃▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
41.5 μs Histogram: frequency by time 64.3 μs <
Memory estimate: 49.86 KiB, allocs estimate: 354.
The second version has about the same performance; I just checked.
So currently, this layer returns only the final output. Is there a way to write a return expression that can return the entire range of the y-values (i.e.
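One way to get that (a sketch against the generated function above, not necessarily what the Lux code linked below or the final PR does) is to collect every per-layer output into a tuple in the final expression:

```julia
# Sketch: replace the tail of applypairwisefusion above so it returns all per-layer
# outputs rather than only the last one (an assumption, not the merged implementation).
push!(calls, :($(y_symbols[N]) = layers[$N]($(y_symbols[N + 1]))))
push!(calls, :(return tuple($(y_symbols[1:N]...))))
```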
Codecov Report
@@            Coverage Diff             @@
##           master    #1971      +/-   ##
==========================================
- Coverage   87.94%   86.37%   -1.58%
==========================================
  Files          19       19
  Lines        1485     1512      +27
==========================================
  Hits         1306     1306
- Misses        179      206      +27

Continue to review full report at Codecov.
@theabhirath see https://github.com/avik-pal/Lux.jl/blob/0ca5a265b7ef8c6d3da2b2dfacfefc56a39f5163/src/layers/basic.jl#L329
That did it, thank you! 😄
Add pretty printing
Also add NEWS.md entry
Bump on this? I've added a very basic test (one that occurred to me immediately, I'll find some more and add them later). It'll help me complete some of the Metalhead models that have been stalling (Res2Net, for example).
if T <: Tuple && lx != N
  throw(ArgumentError("PairwiseFusion with $N sub-layers can take one input or $N inputs, but got $lx inputs"))
end
Can you toss this in a function and mark it @non_differentiable?
Done on the other PR
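For illustration, that refactor might look roughly like this (the helper name `_pairwise_check` and its signature are assumptions, not the code that landed in the other PR):

```julia
using ChainRulesCore: @non_differentiable

# Sketch (names are assumptions): pull the input-count check into a helper and mark it
# non-differentiable so Zygote skips it entirely.
function _pairwise_check(x::Tuple, layers)
  N, lx = length(layers), length(x)
  lx == N || throw(ArgumentError("PairwiseFusion with $N sub-layers can take one input or $N inputs, but got $lx inputs"))
  return nothing
end
_pairwise_check(x, layers) = nothing   # a single non-tuple input is always accepted

@non_differentiable _pairwise_check(x, layers)
```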
2. Any other kind of input:

```julia
y = x
for i in 1:N
  y = connection(x, layers[i](y))
end
```
This will get a test, I presume?
Added on the other PR
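For reference, a minimal test along those lines might look like this (the layer sizes and the loss are illustrative assumptions, not the tests added in the other PR):

```julia
using Flux, Zygote, Test

m = PairwiseFusion(+, Dense(2, 2), Dense(2, 2))   # size-preserving layers so + works on one shared input
x = rand(Float32, 2, 3)                           # a single (non-tuple) input

y = m(x)
@test y !== nothing                               # forward pass accepts a plain array

grads, = Zygote.gradient(model -> sum(model(x)[1]), m)
@test grads !== nothing                           # gradient w.r.t. the parameters exists
```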
This PR implements a PairwiseFusion layer similar to the one in Lux. This might be especially useful for implementing models like Res2Net, since the forward pass for the bottleneck layer is quite similar to this.

TODO:
- Parallel

PR Checklist