
PairwiseFusion layer #1971

Closed · theabhirath wants to merge 13 commits

Conversation

theabhirath (Member) commented on May 25, 2022

This PR implements a PairwiseFusion layer similar to the one in Lux. This might be especially useful for implementing models like Res2Net, since the forward pass for the bottleneck layer is quite similar to this.
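For context, here is a rough sketch of the forward semantics this layer is aiming for when given a tuple of N inputs. This is illustrative pseudocode only, not the implementation in the PR; the exact argument order of `connection` and the return value are worked out in the review below.

```julia
# Illustrative sketch: N layers, N inputs, previous output fused with the next input.
function pairwise_fusion_sketch(connection, layers, xs::Tuple)
    y = layers[1](xs[1])
    outputs = (y,)
    for i in 2:length(layers)
        y = layers[i](connection(y, xs[i]))  # fuse, then apply the next layer
        outputs = (outputs..., y)
    end
    return outputs
end
```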

TODO:

  • Add pretty-printing similar to Parallel

PR Checklist

  • Tests are added
  • Entry in NEWS.md
  • Documentation, if applicable

theabhirath changed the title from "Pairwise fusion" to "PairwiseFusion layer" on May 25, 2022
src/layers/basic.jl (outdated)
Comment on lines 584 to 598
```julia
function (m::PairwiseFusion)(x::T) where {T}
    nlayers = length(m.layers)
    getinput(i) = T <: Tuple ? x[i] : x
    if T <: Tuple
        nx = length(x)
        if nx != nlayers
            throw(ArgumentError("PairwiseFusion with $nlayers layers takes $nlayers inputs, but got $nx inputs"))
        end
    end
    outputs = [m.layers[1](getinput(1))]
    for i in 2:nlayers
        push!(outputs, m.layers[i](m.connection(getinput(i), outputs[i - 1])))
    end
    return outputs
end
```
Member:

I fear this will be very AD unfriendly. Lux uses a generated function

Member Author:

Hmm, this does seem to be throwing up an error when I try a gradient with a very simple example:

```julia
julia> x = (rand(1, 10), rand(30, 10));

julia> model = PairwiseFusion(+,  Dense(1, 30),  Dense(30, 10));

julia> @benchmark Zygote.gradient(p -> sum($model(p)), $x)
ERROR: DimensionMismatch: dimensions must match: a has dims (Base.OneTo(30), Base.OneTo(10)), b has dims (Base.OneTo(10), Base.OneTo(10)), mismatch at 1
Stacktrace:
  [1] promote_shape
    @ ./indices.jl:178 [inlined]
  [2] promote_shape(a::Matrix{Float64}, b::Matrix{Float64})
    @ Base ./indices.jl:169
  [3] +(A::Matrix{Float64}, Bs::Matrix{Float64})
    @ Base ./arraymath.jl:14
  [4] adjoint
    @ ~/.julia/packages/Zygote/DkIUK/src/lib/array.jl:769 [inlined]
  [5] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
  [6] _pullback
    @ ~/Code/Flux.jl/src/layers/basic.jl:602 [inlined]
  [7] _pullback(ctx::Zygote.Context, f::PairwiseFusion{typeof(+), Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, args::Tuple{Matrix{Float64}, Matrix{Float64}})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
  [8] _pullback
    @ ./REPL[8]:1 [inlined]
  [9] _pullback(ctx::Zygote.Context, f::var"#7#8"{PairwiseFusion{typeof(+), Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}, args::Tuple{Matrix{Float64}, Matrix{Float64}})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
 [10] _pullback(f::Function, args::Tuple{Matrix{Float64}, Matrix{Float64}})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:34
 [11] pullback(f::Function, args::Tuple{Matrix{Float64}, Matrix{Float64}})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:40
 [12] gradient(f::Function, args::Tuple{Matrix{Float64}, Matrix{Float64}})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:75
 [13] var"##core#346"(model#341::PairwiseFusion{typeof(+), Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, x#342::Tuple{Matrix{Float64}, Matrix{Float64}})
    @ Main ~/.julia/packages/BenchmarkTools/7xSXH/src/execution.jl:489
 [14] var"##sample#347"(::Tuple{PairwiseFusion{typeof(+), Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Tuple{Matrix{Float64}, Matrix{Float64}}}, __params::BenchmarkTools.Parameters)
    @ Main ~/.julia/packages/BenchmarkTools/7xSXH/src/execution.jl:495
 [15] _run(b::BenchmarkTools.Benchmark, p::BenchmarkTools.Parameters; verbose::Bool, pad::String, kwargs::Base.Pairs{Symbol, Integer, NTuple{4, Symbol}, NamedTuple{(:samples, :evals, :gctrial, :gcsample), Tuple{Int64, Int64, Bool, Bool}}})
    @ BenchmarkTools ~/.julia/packages/BenchmarkTools/7xSXH/src/execution.jl:99
 [16] #invokelatest#2
    @ ./essentials.jl:801 [inlined]
 [17] #run_result#45
    @ ~/.julia/packages/BenchmarkTools/7xSXH/src/execution.jl:34 [inlined]
 [18] run(b::BenchmarkTools.Benchmark, p::BenchmarkTools.Parameters; progressid::Nothing, nleaves::Float64, ndone::Float64, kwargs::Base.Pairs{Symbol, Integer, NTuple{5, Symbol}, NamedTuple{(:verbose, :samples, :evals, :gctrial, :gcsample), Tuple{Bool, Int64, Int64, Bool, Bool}}})
    @ BenchmarkTools ~/.julia/packages/BenchmarkTools/7xSXH/src/execution.jl:117
 [19] #warmup#54
    @ ~/.julia/packages/BenchmarkTools/7xSXH/src/execution.jl:169 [inlined]
 [20] warmup(item::BenchmarkTools.Benchmark)
    @ BenchmarkTools ~/.julia/packages/BenchmarkTools/7xSXH/src/execution.jl:168
 [21] top-level scope
    @ ~/.julia/packages/BenchmarkTools/7xSXH/src/execution.jl:393
 [22] top-level scope
    @ ~/.julia/packages/CUDA/qAl31/src/initialization.jl:52
```

The forward pass works just fine, so I'm not sure what's going wrong. I'll try a generated function and see if that makes a difference.

Member:

I would recommend designing specifically for the case where x is a tuple, then adding a method that turns non-tuples into tuples for the other cases.

theabhirath (Member Author), May 26, 2022:

So I've used a generated function. It's pretty much the same as the one in Lux, but that one stops at a slightly different stage: it stops after combining the output with the next input, meaning N layers actually end up requiring N + 1 inputs, unlike what's advertised. The gradient works now, but it's consistently about 6 µs slower than Lux when I benchmark it; I'm not really sure why.
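To spell out the difference: the Lux version described above ends each step with the fusion, roughly `y = connection(layers[i](y), x[i + 1])` for `i in 1:N` starting from `y = x[1]`, so it consumes N + 1 inputs; the version in this PR ends each step with a layer call, roughly `y = layers[i](connection(y, x[i]))` starting from `y = layers[1](x[1])`, so N layers consume exactly N inputs.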

Member:

Also, let's try to avoid generated functions if possible; we have such an aesthetically pleasant codebase...

The problem is array mutation. Maybe use a comprehension instead?

CarloLucibello (Member), May 26, 2022:

And as @darsnack suggested

```julia
(m::PairwiseFusion)(x) = m(ntuple(i -> x, length(m.layers)))

function (m::PairwiseFusion)(x::Tuple)
   ...
```
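That is, a single (non-tuple) input is just replicated into an N-tuple via `ntuple`, so only the `Tuple` method needs a real implementation.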

theabhirath (Member Author), May 26, 2022:

Okay, so I tried that, and it benchmarks more than 2x slower:

Without generated function

```julia
function (m::PairwiseFusion)(x::Tuple)
    nlayers = length(m.layers)
    nx = length(x)
    if nx != nlayers
        throw(ArgumentError("PairwiseFusion with $nlayers layers takes $nlayers inputs, but got $nx inputs"))
    end
    out(::Val{1}) = m.layers[1](x[1])
    out(::Val{i}) where i = m.layers[i](m.connection(out(Val(i - 1)), x[i]))
    outputs = [out(Val(i)) for i in 1:nlayers]
    return outputs
end
```

Benchmark

```julia
julia> model = PairwiseFusion(+,  Dense(1, 30),  Dense(30, 10));

julia> x = (rand(1, 10), rand(30, 10));

julia> @benchmark Zygote.gradient(p -> sum($model(p)[1]), $x)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  50.833 μs …   7.276 ms  ┊ GC (min … max): 0.00% … 98.10%
 Time  (median):     55.209 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   60.031 μs ± 174.896 μs  ┊ GC (mean ± σ):  7.05% ±  2.40%

            ▄█▇▇▇▃▁
  ▁▁▃▄▅▇▆▇▇█████████▆▅▅▄▄▃▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▃
  50.8 μs         Histogram: frequency by time         70.8 μs <

 Memory estimate: 53.98 KiB, allocs estimate: 258.
```

Storing no intermediate outputs (i.e. directly calling out(nlayers) instead of creating the outputs array) brings this down marginally, but not enough:

```julia
julia> @benchmark Zygote.gradient(p -> sum($model(p)[1]), $x)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  45.625 μs …   7.820 ms  ┊ GC (min … max): 0.00% … 97.97%
 Time  (median):     48.916 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   52.350 μs ± 154.053 μs  ┊ GC (mean ± σ):  5.80% ±  1.96%

            ▁▂▅▇▅█▆██▄▄▁▁
  ▁▁▁▂▂▂▃▄▆▆█████████████▇▆▅▄▄▃▃▃▂▂▂▂▂▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▃
  45.6 μs         Histogram: frequency by time         57.1 μs <

 Memory estimate: 35.73 KiB, allocs estimate: 269.
```

With generated function

```julia
function (m::PairwiseFusion)(x::T) where {T}
    lx = length(x)
    N = length(m.layers)
    if T <: Tuple && lx != N
        throw(ArgumentError("PairwiseFusion with $N sub-layers can take one input or $N inputs, but got $lx inputs"))
    end
    applypairwisefusion(m.layers, m.connection, x)
end

@generated function applypairwisefusion(layers::NamedTuple{names}, connection, x::T) where {names, T}
    N = length(names)
    y_symbols = [gensym() for _ in 1:(N + 1)]
    getinput(i) = T <: Tuple ? :(x[$i]) : :x
    calls = [:($(y_symbols[N + 1]) = $(getinput(1)))]
    append!(calls,
            [:($(y_symbols[i]) = layers[$i]($(y_symbols[N + 1]));
               $(y_symbols[N + 1]) = connection($(y_symbols[i]), $(getinput(i + 1))))
             for i in 1:(N - 1)])
    push!(calls, :($(y_symbols[N]) = layers[$N]($(y_symbols[N + 1]))))
    push!(calls, :(return $(y_symbols[N])))
    return Expr(:block, calls...)
end
```
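For concreteness, this is roughly what `applypairwisefusion` unrolls to for `N = 2` with a tuple input (the symbol names are illustrative; the real `gensym`s are opaque):

```julia
# Approximate expansion of the generated body for N = 2 and x::Tuple (illustrative only):
y3 = x[1]                   # y_symbols[N + 1] holds the running fused value
y1 = layers[1](y3)          # output of the first layer
y3 = connection(y1, x[2])   # fuse it with the next input
y2 = layers[2](y3)          # output of the last layer; this is what gets returned
```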

Benchmark:

```julia
julia> model = PairwiseFusion(+,  Dense(1, 30),  Dense(30, 10));

julia> x = (rand(1, 10), rand(30, 10));

julia> @benchmark Zygote.gradient(p -> sum($model(p)[1]), $x)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  22.042 μs …   4.498 ms  ┊ GC (min … max): 0.00% … 97.52%
 Time  (median):     24.125 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   27.066 μs ± 107.292 μs  ┊ GC (mean ± σ):  9.55% ±  2.40%

       ▃▅▅▆▆▆▇▄█▆▅▃▂
  ▁▂▃▅███████████████▇▆▅▄▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▃
  22 μs           Histogram: frequency by time         31.8 μs <

 Memory estimate: 28.23 KiB, allocs estimate: 138.
```

CarloLucibello (Member), May 26, 2022:

```julia
out(::Val{1}) = m.layers[1](x[1])
out(::Val{i}) where i = m.layers[i](m.connection(out(Val(i - 1)), x[i]))
outputs = [out(Val(i)) for i in 1:nlayers]
```

This doesn't reuse previous computation, so it scales quadratically with nlayers. You want something like (didn't run the code so it may be buggy):

```julia
h = x[1]
out(::Val{1}) = m.layers[1](h)
out(::Val{i}) where i = m.layers[i](m.connection(h, x[i]))
outputs = [(h = out(Val(i)); h) for i in 1:nlayers]
```

CarloLucibello (Member), May 26, 2022:

or just

```julia
h = x[1]
out(i) = i === 1 ? m.layers[1](h) : m.layers[i](m.connection(h, x[i]))
outputs = [(h = out(i); h) for i in 1:nlayers]
```

if performance is the same

theabhirath (Member Author), May 26, 2022:

I actually tried that later (I forgot to add the benchmarks for it). It does better than the purely recursive version but it's still way off the pace of the generated function:

```julia
julia> @benchmark Zygote.gradient(p -> sum($model(p)[1]), $x)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  41.458 μs …   6.534 ms  ┊ GC (min … max): 0.00% … 96.79%
 Time  (median):     44.791 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   49.408 μs ± 140.782 μs  ┊ GC (mean ± σ):  6.80% ±  2.37%

        ▃▅█▃
  ▁▂▂▄▅▇████▆▄▃▃▃▃▃▃▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
  41.5 μs         Histogram: frequency by time         64.3 μs <

 Memory estimate: 49.86 KiB, allocs estimate: 354.
```

The second version has roughly the same performance; I just checked.

theabhirath (Member Author) commented on May 26, 2022

So currently, this layer returns only the final output. Is there a way to write a return expression that can return the entire range of the y-values (i.e. $[y_1 \ldots y_n]$)? I tried interpolating in different ways but I couldn't quite figure out how to do it 😅

codecov-commenter commented on May 26, 2022

Codecov Report

Merging #1971 (0d631fc) into master (b978c90) will decrease coverage by 1.57%.
The diff coverage is 6.89%.

```diff
@@            Coverage Diff             @@
##           master    #1971      +/-   ##
==========================================
- Coverage   87.94%   86.37%   -1.58%
==========================================
  Files          19       19
  Lines        1485     1512      +27
==========================================
  Hits         1306     1306
- Misses        179      206      +27
```

| Impacted Files | Coverage Δ |
| --- | --- |
| src/layers/show.jl | 71.79% <0.00%> (-0.94%) ⬇️ |
| src/layers/basic.jl | 68.35% <7.14%> (-13.47%) ⬇️ |

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

avik-pal (Member):

> So currently, this layer returns only the final output. Is there a way to write a return expression that can return the entire range of the y-values (i.e. $[y_1 \ldots y_n]$)? I tried interpolating in different ways but I couldn't quite figure out how to do it 😅

@theabhirath see https://github.com/avik-pal/Lux.jl/blob/0ca5a265b7ef8c6d3da2b2dfacfefc56a39f5163/src/layers/basic.jl#L329

theabhirath (Member Author):

> So currently, this layer returns only the final output. Is there a way to write a return expression that can return the entire range of the y-values (i.e. [y1…yn])? I tried interpolating in different ways but I couldn't quite figure out how to do it 😅
>
> @theabhirath see https://github.com/avik-pal/Lux.jl/blob/0ca5a265b7ef8c6d3da2b2dfacfefc56a39f5163/src/layers/basic.jl#L329

That did it, thank you! 😄
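For reference, the Lux approach linked above builds the final return expression out of all the output symbols; adapted to the generated function earlier in this thread, the change would look roughly like this sketch (not the exact final code):

```julia
# Sketch: return a tuple of every layer's output instead of only the last symbol.
push!(calls, :(return tuple($(y_symbols[1:N]...))))
```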

theabhirath (Member Author):

Bump on this? I've added a very basic test (one that occurred to me immediately, I'll find some more and add them later). It'll help me complete some of the Metalhead models that have been stalling (Res2Net, for example)

src/layers/basic.jl
Comment on lines +582 to +584
```julia
if T <: Tuple && lx != N
    throw(ArgumentError("PairwiseFusion with $N sub-layers can take one input or $N inputs, but got $lx inputs"))
end
```
Member:

Can you toss this in a function and mark it `@non_differentiable`?
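A minimal sketch of what that refactor could look like (the helper name is hypothetical; the actual change landed in the follow-up PR):

```julia
using ChainRulesCore: @non_differentiable

# Hypothetical helper: the input-count check is pure error handling,
# so it can be excluded from differentiation entirely.
function _pairwise_check(lx, N, T)
    if T <: Tuple && lx != N
        throw(ArgumentError("PairwiseFusion with $N sub-layers can take one input or $N inputs, but got $lx inputs"))
    end
end
@non_differentiable _pairwise_check(lx, N, T)
```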

Member Author:

Done on the other PR

src/layers/basic.jl
Comment on lines +556 to +562
2. Any other kind of input:

```julia
y = x
for i in 1:N
    y = connection(x, layers[i](y))
end
```
Member:

This will get a test, I presume?
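Something along these lines, presumably; a hypothetical smoke test (not the test that was actually added in the follow-up PR), assuming the layer returns one output per sub-layer:

```julia
using Flux, Zygote, Test

# Hypothetical smoke test of the single (non-tuple) input path documented above.
m = PairwiseFusion(+, Dense(10 => 10), Dense(10 => 10))
x = rand(Float32, 10, 4)
@test length(m(x)) == 2                                # one output per sub-layer (assumed)
g = Zygote.gradient(x -> sum(abs2, m(x)[end]), x)[1]   # gradient w.r.t. the input
@test size(g) == size(x)
```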

Member Author:

Added on the other PR

theabhirath mentioned this pull request on Jun 2, 2022
theabhirath deleted the pairwise-fusion branch on June 6, 2022 at 16:23