PairwiseFusion layer #1971

Closed
wants to merge 13 commits
2 changes: 1 addition & 1 deletion src/Flux.jl
@@ -13,7 +13,7 @@ using Zygote, ChainRulesCore
using Zygote: Params, @adjoint, gradient, pullback, @nograd
export gradient

-export Chain, Dense, Maxout, SkipConnection, Parallel,
+export Chain, Dense, Maxout, SkipConnection, Parallel, PairwiseFusion,
RNN, LSTM, GRU, GRUv3,
SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,
96 changes: 91 additions & 5 deletions src/layers/basic.jl
@@ -487,7 +487,7 @@ end
Parallel(connection, layers...) = Parallel(connection, layers)
function Parallel(connection; kw...)
layers = NamedTuple(kw)
-if :layers in Base.keys(layers) || :connection in Base.keys(layers)
+if :layers in keys(layers) || :connection in keys(layers)
throw(ArgumentError("a Parallel layer cannot have a named sub-layer called `connection` or `layers`"))
end
isempty(layers) && return Parallel(connection, ())
@@ -510,16 +510,102 @@ end
Base.getindex(m::Parallel, i) = m.layers[i]
Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i])
Base.getindex(m::Parallel{<:Any, <:NamedTuple}, i::AbstractVector) =
-Parallel(m.connection, NamedTuple{Base.keys(m)[i]}(Tuple(m.layers)[i]))
+Parallel(m.connection, NamedTuple{keys(m)[i]}(Tuple(m.layers)[i]))

-Base.keys(m::Parallel) = Base.keys(getfield(m, :layers))
+Base.keys(m::Parallel) = keys(getfield(m, :layers))

function Base.show(io::IO, m::Parallel)
print(io, "Parallel(", m.connection, ", ")
_show_layers(io, m.layers)
print(io, ")")
end

"""
PairwiseFusion(connection, layers...)

```
x1 --> layer1 --> y1
|
|--> connection --> layer2 --> y2
| |
x2 |--> connection --> layer3 --> y3
| |
x3 |--> connection --> y4
|
x4
```

## Arguments

- `connection`: Takes 2 inputs and combines them
- `layers`: The layers whose outputs are combined

## Inputs

This layer behaves differently based on input type:

1. Input `x` is a tuple of length `N`. Then `layers` must be a tuple of length `N`. The computation is as follows:

```julia
y = x[1]
for i in 1:N
y = connection(x[i], layers[i](y))
end
```

2. With any other kind of input `x`, the computation is as follows:

```julia
y = x
for i in 1:N
y = connection(x, layers[i](y))
end
```

Comment on lines +556 to +562
Member: This will get a test, I presume?
Member Author: Added on the other PR

## Returns

`PairwiseFusion` returns a tuple of length `N` with the output of each fusion (`y1`, `y2`, ..., `yN` in the example above).
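
## Examples

A minimal usage sketch (illustrative only; the sizes mirror the benchmark example discussed later in this thread):

```julia
model = PairwiseFusion(+, Dense(1, 30), Dense(30, 10))

# Tuple input: one input per layer.
x = (rand(Float32, 1, 10), rand(Float32, 30, 10))
ys = model(x)
size(ys[1])  # (30, 10): Dense(1, 30) applied to x[1]
size(ys[2])  # (10, 10): Dense(30, 10) applied to the fusion of x[2] and ys[1]

# Non-tuple input: the same x is fed to every fusion step.
model2 = PairwiseFusion(+, Dense(8, 8), Dense(8, 8))
ys2 = model2(rand(Float32, 8, 4))
```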
"""
struct PairwiseFusion{F, T<:Union{Tuple, NamedTuple}}
  connection::F
  layers::T
end

PairwiseFusion(connection, layers...) = PairwiseFusion(connection, layers)
function PairwiseFusion(connection; kw...)
  layers = NamedTuple(kw)
  if :layers in keys(layers) || :connection in keys(layers)
    throw(ArgumentError("a PairwiseFusion layer cannot have a named sub-layer called `connection` or `layers`"))
  end
  isempty(layers) && return PairwiseFusion(connection, ())
  return PairwiseFusion(connection, layers)
end

function (m::PairwiseFusion)(x::T) where {T}
  nlayers = length(m.layers)
  getinput(i) = T <: Tuple ? x[i] : x
  if T <: Tuple
    nx = length(x)
    if nx != nlayers
      throw(ArgumentError("PairwiseFusion with $nlayers layers takes $nlayers inputs, but got $nx inputs"))
    end
  end
  outputs = [m.layers[1](getinput(1))]
  for i in 2:nlayers
    push!(outputs, m.layers[i](m.connection(getinput(i), outputs[i - 1])))
  end
  return outputs
end
Member: I fear this will be very AD unfriendly. Lux uses a generated function.

Member Author: Hmm, this does seem to be throwing an error when I try a gradient with a very simple example:

julia> x = (rand(1, 10), rand(30, 10));

julia> model = PairwiseFusion(+,  Dense(1, 30),  Dense(30, 10));

julia> @benchmark Zygote.gradient(p -> sum($model(p)), $x)
ERROR: DimensionMismatch: dimensions must match: a has dims (Base.OneTo(30), Base.OneTo(10)), b has dims (Base.OneTo(10), Base.OneTo(10)), mismatch at 1
Stacktrace:
  [1] promote_shape
    @ ./indices.jl:178 [inlined]
  [2] promote_shape(a::Matrix{Float64}, b::Matrix{Float64})
    @ Base ./indices.jl:169
  [3] +(A::Matrix{Float64}, Bs::Matrix{Float64})
    @ Base ./arraymath.jl:14
  [4] adjoint
    @ ~/.julia/packages/Zygote/DkIUK/src/lib/array.jl:769 [inlined]
  [5] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
  [6] _pullback
    @ ~/Code/Flux.jl/src/layers/basic.jl:602 [inlined]
  [7] _pullback(ctx::Zygote.Context, f::PairwiseFusion{typeof(+), Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, args::Tuple{Matrix{Float64}, Matrix{Float64}})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
  [8] _pullback
    @ ./REPL[8]:1 [inlined]
  [9] _pullback(ctx::Zygote.Context, f::var"#7#8"{PairwiseFusion{typeof(+), Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}, args::Tuple{Matrix{Float64}, Matrix{Float64}})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
 [10] _pullback(f::Function, args::Tuple{Matrix{Float64}, Matrix{Float64}})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:34
 [11] pullback(f::Function, args::Tuple{Matrix{Float64}, Matrix{Float64}})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:40
 [12] gradient(f::Function, args::Tuple{Matrix{Float64}, Matrix{Float64}})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:75
 [13] var"##core#346"(model#341::PairwiseFusion{typeof(+), Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, x#342::Tuple{Matrix{Float64}, Matrix{Float64}})
    @ Main ~/.julia/packages/BenchmarkTools/7xSXH/src/execution.jl:489
 [14] var"##sample#347"(::Tuple{PairwiseFusion{typeof(+), Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Tuple{Matrix{Float64}, Matrix{Float64}}}, __params::BenchmarkTools.Parameters)
    @ Main ~/.julia/packages/BenchmarkTools/7xSXH/src/execution.jl:495
 [15] _run(b::BenchmarkTools.Benchmark, p::BenchmarkTools.Parameters; verbose::Bool, pad::String, kwargs::Base.Pairs{Symbol, Integer, NTuple{4, Symbol}, NamedTuple{(:samples, :evals, :gctrial, :gcsample), Tuple{Int64, Int64, Bool, Bool}}})
    @ BenchmarkTools ~/.julia/packages/BenchmarkTools/7xSXH/src/execution.jl:99
 [16] #invokelatest#2
    @ ./essentials.jl:801 [inlined]
 [17] #run_result#45
    @ ~/.julia/packages/BenchmarkTools/7xSXH/src/execution.jl:34 [inlined]
 [18] run(b::BenchmarkTools.Benchmark, p::BenchmarkTools.Parameters; progressid::Nothing, nleaves::Float64, ndone::Float64, kwargs::Base.Pairs{Symbol, Integer, NTuple{5, Symbol}, NamedTuple{(:verbose, :samples, :evals, :gctrial, :gcsample), Tuple{Bool, Int64, Int64, Bool, Bool}}})
    @ BenchmarkTools ~/.julia/packages/BenchmarkTools/7xSXH/src/execution.jl:117
 [19] #warmup#54
    @ ~/.julia/packages/BenchmarkTools/7xSXH/src/execution.jl:169 [inlined]
 [20] warmup(item::BenchmarkTools.Benchmark)
    @ BenchmarkTools ~/.julia/packages/BenchmarkTools/7xSXH/src/execution.jl:168
 [21] top-level scope
    @ ~/.julia/packages/BenchmarkTools/7xSXH/src/execution.jl:393
 [22] top-level scope
    @ ~/.julia/packages/CUDA/qAl31/src/initialization.jl:52

The forward pass works just fine, so I'm not sure what's going wrong. I'll try a generated function and see if that makes a difference.

Member: I would recommend designing specifically for the case where x is a tuple, then adding a method that turns non-tuples into tuples for the other cases.

Member Author (@theabhirath, May 26, 2022): So I've used a generated function. It's pretty much the same as the one in Lux, but that one stops at a slightly different stage (it stops after the combination of the output and the next input, meaning N layers actually end up requiring N + 1 inputs, unlike what's advertised). The gradient works now, but it's consistently about 6 µs slower than Lux when I benchmark it; I'm not really sure why.

Member: Also, let's try to avoid a generated function if possible; we have such an aesthetically pleasant codebase...

The problem is array mutation. Maybe use a comprehension instead?

Member (@CarloLucibello, May 26, 2022): And as @darsnack suggested:

(m::PairwiseFusion)(x) = m(ntuple(i -> x, length(m.layers)))

function (m::PairwiseFusion)(x::Tuple)
   ...

Member Author (@theabhirath, May 26, 2022): Okay, so I tried that, and it benchmarks more than 2x slower:

Without the generated function:

function (m::PairwiseFusion)(x::Tuple)
  nlayers = length(m.layers)
  nx = length(x)
  if nx != nlayers
    throw(ArgumentError("PairwiseFusion with $nlayers layers takes $nlayers inputs, but got $nx inputs"))
  end
  out(::Val{1}) = m.layers[1](x[1])
  out(::Val{i}) where i = m.layers[i](m.connection(out(Val(i - 1)), x[i]))
  outputs = [out(Val(i)) for i in 1:nlayers]
  return outputs
end

Benchmark:

julia> model = PairwiseFusion(+,  Dense(1, 30),  Dense(30, 10));

julia> x = (rand(1, 10), rand(30, 10));

julia> @benchmark Zygote.gradient(p -> sum($model(p)[1]), $x)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min  max):  50.833 μs    7.276 ms  ┊ GC (min  max): 0.00%  98.10%
 Time  (median):     55.209 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   60.031 μs ± 174.896 μs  ┊ GC (mean ± σ):  7.05% ±  2.40%

            ▄█▇▇▇▃▁
  ▁▁▃▄▅▇▆▇▇█████████▆▅▅▄▄▃▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▃
  50.8 μs         Histogram: frequency by time         70.8 μs <

 Memory estimate: 53.98 KiB, allocs estimate: 258.

Storing no intermediate outputs (i.e. just calling `out(Val(nlayers))` directly instead of creating the `outputs` array) brings this down marginally, but not enough:

julia> @benchmark Zygote.gradient(p -> sum($model(p)[1]), $x)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min  max):  45.625 μs    7.820 ms  ┊ GC (min  max): 0.00%  97.97%
 Time  (median):     48.916 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   52.350 μs ± 154.053 μs  ┊ GC (mean ± σ):  5.80% ±  1.96%

            ▁▂▅▇▅█▆██▄▄▁▁
  ▁▁▁▂▂▂▃▄▆▆█████████████▇▆▅▄▄▃▃▃▂▂▂▂▂▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▃
  45.6 μs         Histogram: frequency by time         57.1 μs <

 Memory estimate: 35.73 KiB, allocs estimate: 269.
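
For reference, the variant described above would look roughly like this (a sketch only, not the benchmarked code verbatim; it keeps the recursive `out` helper and simply skips building the `outputs` array):

```julia
function (m::PairwiseFusion)(x::Tuple)
  nlayers = length(m.layers)
  out(::Val{1}) = m.layers[1](x[1])
  out(::Val{i}) where i = m.layers[i](m.connection(out(Val(i - 1)), x[i]))
  return out(Val(nlayers))  # only the final output; no intermediate array is built
end
```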

With the generated function:

function (m::PairwiseFusion)(x::T) where {T}
  lx = length(x)
  N = length(m.layers)
  if T <: Tuple && lx != N
    throw(ArgumentError("PairwiseFusion with $N sub-layers can take one input or $N inputs, but got $lx inputs"))
  end
  applypairwisefusion(m.layers, m.connection, x)
end

@generated function applypairwisefusion(layers::Tuple{Vararg{Any, N}}, connection, x::T) where {N, T}
  y_symbols = [gensym() for _ in 1:(N + 1)]
  getinput(i) = T <: Tuple ? :(x[$i]) : :x
  calls = [:($(y_symbols[N + 1]) = $(getinput(1)))]
  append!(calls,
            [:($(y_symbols[i]) = layers[$i]($(y_symbols[N + 1]));
               $(y_symbols[N + 1]) = connection($(y_symbols[i]), $(getinput(i + 1)))) 
             for i in 1:N - 1])
  push!(calls, :($(y_symbols[N]) = layers[$N]($(y_symbols[N + 1]))))
  push!(calls, :(return $(y_symbols[N])))
  return Expr(:block, calls...)
end
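
For a concrete sense of what this produces, the body generated for two layers and a tuple input is roughly the following (the real generated code uses gensym'd names; readable ones are substituted here for illustration):

```julia
# Approximate expansion for N = 2 layers and a Tuple input:
y  = x[1]                   # running value threaded between stages
y1 = layers[1](y)           # first layer output
y  = connection(y1, x[2])   # fuse the first output with the second input
y2 = layers[2](y)           # second layer output; this version returns only y2
```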

Benchmark:

julia> model = PairwiseFusion(+,  Dense(1, 30),  Dense(30, 10));

julia> x = (rand(1, 10), rand(30, 10));

julia> @benchmark Zygote.gradient(p -> sum($model(p)[1]), $x)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min  max):  22.042 μs    4.498 ms  ┊ GC (min  max): 0.00%  97.52%
 Time  (median):     24.125 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   27.066 μs ± 107.292 μs  ┊ GC (mean ± σ):  9.55% ±  2.40%

       ▃▅▅▆▆▆▇▄█▆▅▃▂
  ▁▂▃▅███████████████▇▆▅▄▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▃
  22 μs           Histogram: frequency by time         31.8 μs <

 Memory estimate: 28.23 KiB, allocs estimate: 138.

Member (@CarloLucibello, May 26, 2022):
out(::Val{1}) = m.layers[1](x[1])
out(::Val{i}) where i = m.layers[i](m.connection(out(Val(i - 1)), x[i]))
outputs = [out(Val(i)) for i in 1:nlayers]

This doesn't reuse previous computation, so it scales quadratically with nlayers. You want something like (didn't run the code so it may be buggy):

h = x[1]
out(::Val{1}) = m.layers[1](h)
out(::Val{i}) where i = m.layers[i](m.connection(h, x[i]))
outputs = [(h = out(Val(i)); h) for i in 1:nlayers]

Member (@CarloLucibello, May 26, 2022): or just

h = x[1]
out(i) = i === 1 ? m.layers[1](h) : m.layers[i](m.connection(h, x[i]))
outputs = [(h = out(i); h) for i in 1:nlayers]

if performance is the same

Member Author (@theabhirath, May 26, 2022):
I actually tried that later (I forgot to add the benchmarks for it). It does better than the purely recursive version but it's still way off the pace of the generated function:

julia> @benchmark Zygote.gradient(p -> sum($model(p)[1]), $x)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min  max):  41.458 μs    6.534 ms  ┊ GC (min  max): 0.00%  96.79%
 Time  (median):     44.791 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   49.408 μs ± 140.782 μs  ┊ GC (mean ± σ):  6.80% ±  2.37%

        ▃▅█▃
  ▁▂▂▄▅▇████▆▄▃▃▃▃▃▃▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
  41.5 μs         Histogram: frequency by time         64.3 μs <

 Memory estimate: 49.86 KiB, allocs estimate: 354.

The second version has about the same performance; I just checked.


@functor PairwiseFusion

Base.getindex(m::PairwiseFusion, i) = m.layers[i]
Base.getindex(m::PairwiseFusion, i::AbstractVector) = PairwiseFusion(m.connection, m.layers[i])
Base.getindex(m::PairwiseFusion{<:Any, <:NamedTuple}, i::AbstractVector) =
  PairwiseFusion(m.connection, NamedTuple{keys(m)[i]}(Tuple(m.layers)[i]))

Base.keys(m::PairwiseFusion) = keys(getfield(m, :layers))

"""
Embedding(in => out; init=randn)

@@ -556,7 +642,7 @@ end
@functor Embedding

Embedding((in, out)::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(init(out, in))

(m::Embedding)(x::Integer) = m.weight[:, x]
(m::Embedding)(x::AbstractVector) = NNlib.gather(m.weight, x)
(m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)
@@ -565,7 +651,7 @@ function (m::Embedding)(x::Union{OneHotVector{T,L}, OneHotMatrix{T,L}}) where {T
size(m.weight, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(m.weight, 2)) != $L"))
return m(onecold(x))
end

function Base.show(io::IO, m::Embedding)
print(io, "Embedding(", size(m.weight, 2), " => ", size(m.weight, 1), ")")
end