PairwiseFusion layer #1971
Changes from 3 commits
```diff
@@ -487,7 +487,7 @@ end
 Parallel(connection, layers...) = Parallel(connection, layers)
 function Parallel(connection; kw...)
   layers = NamedTuple(kw)
-  if :layers in Base.keys(layers) || :connection in Base.keys(layers)
+  if :layers in keys(layers) || :connection in keys(layers)
     throw(ArgumentError("a Parallel layer cannot have a named sub-layer called `connection` or `layers`"))
   end
   isempty(layers) && return Parallel(connection, ())
```
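As a quick illustration of what this guard protects (a sketch, not part of the diff; the names `a` and `b` are arbitrary):

```julia
using Flux

# Keyword sub-layers are stored in a NamedTuple, so they stay addressable by name.
p = Parallel(vcat; a = Dense(2, 3), b = Dense(2, 3))
p[:a]   # retrieves the Dense(2, 3) registered under :a

# :connection and :layers collide with Parallel's own fields, so they are rejected:
# Parallel(vcat; layers = Dense(2, 3))  # throws the ArgumentError above
```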
```diff
@@ -510,16 +510,102 @@ end
 Base.getindex(m::Parallel, i) = m.layers[i]
 Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i])
 Base.getindex(m::Parallel{<:Any, <:NamedTuple}, i::AbstractVector) =
-  Parallel(m.connection, NamedTuple{Base.keys(m)[i]}(Tuple(m.layers)[i]))
+  Parallel(m.connection, NamedTuple{keys(m)[i]}(Tuple(m.layers)[i]))

-Base.keys(m::Parallel) = Base.keys(getfield(m, :layers))
+Base.keys(m::Parallel) = keys(getfield(m, :layers))

 function Base.show(io::IO, m::Parallel)
   print(io, "Parallel(", m.connection, ", ")
   _show_layers(io, m.layers)
   print(io, ")")
 end

+"""
+    PairwiseFusion(connection, layers...)
+
+```
+x1 --> layer1 --> y1
+                  |
+                  |--> connection --> layer2 --> y2
+                  |                              |
+x2 ---------------|                              |--> connection --> layer3 --> y3
+                                                 |                              |
+x3 ----------------------------------------------|                              |--> connection --> y4
+                                                                                |
+x4 -----------------------------------------------------------------------------|
+```
+
+## Arguments
+
+- `connection`: takes two inputs and combines them
+- `layers`: the layers whose outputs are combined
+
+## Inputs
+
+This layer behaves differently based on input type:
+
+1. If the input `x` is a tuple of length `N`, then `layers` must also be a tuple of length `N`. The computation is as follows:
+
+```julia
+y = x[1]
+for i in 1:N
+    y = connection(x[i], layers[i](y))
+end
+```
+
+2. With any other kind of input:
+
+```julia
+y = x
+for i in 1:N
+    y = connection(x, layers[i](y))
+end
+```
+
+## Returns
+
+`PairwiseFusion` returns a tuple of length `N` with the output of each fusion (`(y1, y2, ..., yN)` in the example above).
+"""
+struct PairwiseFusion{F, T<:Union{Tuple, NamedTuple}}
+  connection::F
+  layers::T
+end
+
+PairwiseFusion(connection, layers...) = PairwiseFusion(connection, layers)
+function PairwiseFusion(connection; kw...)
+  layers = NamedTuple(kw)
+  if :layers in keys(layers) || :connection in keys(layers)
+    throw(ArgumentError("a PairwiseFusion layer cannot have a named sub-layer called `connection` or `layers`"))
+  end
+  isempty(layers) && return PairwiseFusion(connection, ())
+  return PairwiseFusion(connection, layers)
+end
```
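A usage sketch (not part of the diff), reusing the model and inputs from the benchmarks in the review thread below; the output shapes follow from the forward pass defined next:

```julia
using Flux

m = PairwiseFusion(+, Dense(1, 30), Dense(30, 10))
x = (rand(1, 10), rand(30, 10))   # one input per sub-layer

ys = m(x)                         # a Vector of per-stage outputs, as implemented at this commit
size(ys[1])  # (30, 10): Dense(1, 30) applied to x[1]
size(ys[2])  # (10, 10): Dense(30, 10) applied to connection(x[2], ys[1]), i.e. x[2] + ys[1]
```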
```diff
+function (m::PairwiseFusion)(x::T) where {T}
+  nlayers = length(m.layers)
+  getinput(i) = T <: Tuple ? x[i] : x
+  if T <: Tuple
+    nx = length(x)
+    if nx != nlayers
+      throw(ArgumentError("PairwiseFusion with $nlayers layers takes $nlayers inputs, but got $nx inputs"))
+    end
+  end
+  outputs = [m.layers[1](getinput(1))]
+  for i in 2:nlayers
+    push!(outputs, m.layers[i](m.connection(getinput(i), outputs[i - 1])))
+  end
+  return outputs
+end
```
---

I fear this will be very AD unfriendly. Lux uses a generated function.

---

Hmm, this does seem to be throwing up an error when I try a gradient with a very simple example:

```julia
julia> x = (rand(1, 10), rand(30, 10));

julia> model = PairwiseFusion(+, Dense(1, 30), Dense(30, 10));

julia> @benchmark Zygote.gradient(p -> sum($model(p)), $x)
ERROR: DimensionMismatch: dimensions must match: a has dims (Base.OneTo(30), Base.OneTo(10)), b has dims (Base.OneTo(10), Base.OneTo(10)), mismatch at 1
Stacktrace:
  [1] promote_shape
    @ ./indices.jl:178 [inlined]
  [2] promote_shape(a::Matrix{Float64}, b::Matrix{Float64})
    @ Base ./indices.jl:169
  [3] +(A::Matrix{Float64}, Bs::Matrix{Float64})
    @ Base ./arraymath.jl:14
  [4] adjoint
    @ ~/.julia/packages/Zygote/DkIUK/src/lib/array.jl:769 [inlined]
  [5] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
  [6] _pullback
    @ ~/Code/Flux.jl/src/layers/basic.jl:602 [inlined]
  [7] _pullback(ctx::Zygote.Context, f::PairwiseFusion{typeof(+), Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, args::Tuple{Matrix{Float64}, Matrix{Float64}})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
  [8] _pullback
    @ ./REPL[8]:1 [inlined]
  [9] _pullback(ctx::Zygote.Context, f::var"#7#8"{PairwiseFusion{typeof(+), Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}, args::Tuple{Matrix{Float64}, Matrix{Float64}})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
 [10] _pullback(f::Function, args::Tuple{Matrix{Float64}, Matrix{Float64}})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:34
 [11] pullback(f::Function, args::Tuple{Matrix{Float64}, Matrix{Float64}})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:40
 [12] gradient(f::Function, args::Tuple{Matrix{Float64}, Matrix{Float64}})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:75
 [13] var"##core#346"(model#341::PairwiseFusion{typeof(+), Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, x#342::Tuple{Matrix{Float64}, Matrix{Float64}})
    @ Main ~/.julia/packages/BenchmarkTools/7xSXH/src/execution.jl:489
 [14] var"##sample#347"(::Tuple{PairwiseFusion{typeof(+), Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Tuple{Matrix{Float64}, Matrix{Float64}}}, __params::BenchmarkTools.Parameters)
    @ Main ~/.julia/packages/BenchmarkTools/7xSXH/src/execution.jl:495
 [15] _run(b::BenchmarkTools.Benchmark, p::BenchmarkTools.Parameters; verbose::Bool, pad::String, kwargs::Base.Pairs{Symbol, Integer, NTuple{4, Symbol}, NamedTuple{(:samples, :evals, :gctrial, :gcsample), Tuple{Int64, Int64, Bool, Bool}}})
    @ BenchmarkTools ~/.julia/packages/BenchmarkTools/7xSXH/src/execution.jl:99
 [16] #invokelatest#2
    @ ./essentials.jl:801 [inlined]
 [17] #run_result#45
    @ ~/.julia/packages/BenchmarkTools/7xSXH/src/execution.jl:34 [inlined]
 [18] run(b::BenchmarkTools.Benchmark, p::BenchmarkTools.Parameters; progressid::Nothing, nleaves::Float64, ndone::Float64, kwargs::Base.Pairs{Symbol, Integer, NTuple{5, Symbol}, NamedTuple{(:verbose, :samples, :evals, :gctrial, :gcsample), Tuple{Bool, Int64, Int64, Bool, Bool}}})
    @ BenchmarkTools ~/.julia/packages/BenchmarkTools/7xSXH/src/execution.jl:117
 [19] #warmup#54
    @ ~/.julia/packages/BenchmarkTools/7xSXH/src/execution.jl:169 [inlined]
 [20] warmup(item::BenchmarkTools.Benchmark)
    @ ~/.julia/packages/BenchmarkTools/7xSXH/src/execution.jl:168
 [21] top-level scope
    @ ~/.julia/packages/BenchmarkTools/7xSXH/src/execution.jl:393
 [22] top-level scope
    @ ~/.julia/packages/CUDA/qAl31/src/initialization.jl:52
```

The forward pass works just fine, so I'm not sure what's going wrong. I'll try with a generated function and see if that makes a difference.

---

I would recommend designing specifically for the case where …

---

So I've used a generated function. It's pretty much the same as the one in Lux, but that one stops at a slightly different stage (it stops after the combination of the output and the next input, meaning `N` layers actually end up requiring `N + 1` inputs, unlike what's advertised). The gradient works now, but it's consistently about 6 µs off the pace of Lux when I benchmark it — not really sure why.

---

Also, let's try to avoid generated functions if possible; we have such an aesthetically pleasant codebase... The problem is array mutation. Maybe use a comprehension instead?

And as @darsnack suggested:

```julia
(m::PairwiseFusion)(x) = m(ntuple(i -> x, length(m.layers)))
function (m::PairwiseFusion)(x::Tuple)
    ...
```

---

Okay, so I tried that, and it benchmarks more than 2x slower.

**Without the generated function:**

```julia
function (m::PairwiseFusion)(x::Tuple)
  nlayers = length(m.layers)
  nx = length(x)
  if nx != nlayers
    throw(ArgumentError("PairwiseFusion with $nlayers layers takes $nlayers inputs, but got $nx inputs"))
  end
  out(::Val{1}) = m.layers[1](x[1])
  out(::Val{i}) where i = m.layers[i](m.connection(out(Val(i - 1)), x[i]))
  outputs = [out(Val(i)) for i in 1:nlayers]
  return outputs
end
```

**Benchmark:**

```julia
julia> model = PairwiseFusion(+, Dense(1, 30), Dense(30, 10));

julia> x = (rand(1, 10), rand(30, 10));

julia> @benchmark Zygote.gradient(p -> sum($model(p)[1]), $x)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  50.833 μs …   7.276 ms  ┊ GC (min … max): 0.00% … 98.10%
 Time  (median):     55.209 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   60.031 μs ± 174.896 μs  ┊ GC (mean ± σ):  7.05% ±  2.40%

              ▄█▇▇▇▃▁
  ▁▁▃▄▅▇▆▇▇█████████▆▅▅▄▄▃▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▃
  50.8 μs         Histogram: frequency by time         70.8 μs <

 Memory estimate: 53.98 KiB, allocs estimate: 258.
```

Storing no intermediate inputs (i.e. just straight away calling …):

```julia
julia> @benchmark Zygote.gradient(p -> sum($model(p)[1]), $x)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  45.625 μs …   7.820 ms  ┊ GC (min … max): 0.00% … 97.97%
 Time  (median):     48.916 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   52.350 μs ± 154.053 μs  ┊ GC (mean ± σ):  5.80% ±  1.96%

       ▁▂▅▇▅█▆██▄▄▁▁
  ▁▁▁▂▂▂▃▄▆▆█████████████▇▆▅▄▄▃▃▃▂▂▂▂▂▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▃
  45.6 μs         Histogram: frequency by time         57.1 μs <

 Memory estimate: 35.73 KiB, allocs estimate: 269.
```

**With the generated function:**

```julia
function (m::PairwiseFusion)(x::T) where {T}
  lx = length(x)
  N = length(m.layers)
  if T <: Tuple && lx != N
    throw(ArgumentError("PairwiseFusion with $N sub-layers can take one input or $N inputs, but got $lx inputs"))
  end
  applypairwisefusion(m.layers, m.connection, x)
end

@generated function applypairwisefusion(layers::NamedTuple{names}, connection, x::T) where {names, T}
  N = length(names)
  y_symbols = [gensym() for _ in 1:(N + 1)]
  getinput(i) = T <: Tuple ? :(x[$i]) : :x
  calls = [:($(y_symbols[N + 1]) = $(getinput(1)))]
  append!(calls,
          [:($(y_symbols[i]) = layers[$i]($(y_symbols[N + 1]));
             $(y_symbols[N + 1]) = connection($(y_symbols[i]), $(getinput(i + 1))))
           for i in 1:N - 1])
  push!(calls, :($(y_symbols[N]) = layers[$N]($(y_symbols[N + 1]))))
  push!(calls, :(return $(y_symbols[N])))
  return Expr(:block, calls...)
end
```

**Benchmark:**

```julia
julia> model = PairwiseFusion(+, Dense(1, 30), Dense(30, 10));

julia> x = (rand(1, 10), rand(30, 10));

julia> @benchmark Zygote.gradient(p -> sum($model(p)[1]), $x)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  22.042 μs …   4.498 ms  ┊ GC (min … max): 0.00% … 97.52%
 Time  (median):     24.125 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   27.066 μs ± 107.292 μs  ┊ GC (mean ± σ):  9.55% ±  2.40%

     ▃▅▅▆▆▆▇▄█▆▅▃▂
  ▁▂▃▅███████████████▇▆▅▄▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▃
  22 μs          Histogram: frequency by time          31.8 μs <

 Memory estimate: 28.23 KiB, allocs estimate: 138.
```

---

```julia
out(::Val{1}) = m.layers[1](x[1])
out(::Val{i}) where i = m.layers[i](m.connection(out(Val(i - 1)), x[i]))
outputs = [out(Val(i)) for i in 1:nlayers]
```

This doesn't reuse previous computation, so it scales quadratically with the number of layers. Instead:

```julia
h = x[1]
out(::Val{1}) = m.layers[1](h)
out(::Val{i}) where i = m.layers[i](m.connection(h, x[i]))
outputs = [(h = out(Val(i)); h) for i in 1:nlayers]
```

---

Or just

```julia
h = x[1]
out(i) = i === 1 ? m.layers[1](h) : m.layers[i](m.connection(h, x[i]))
outputs = [(h = out(i); h) for i in 1:nlayers]
```

if the performance is the same.

---

I actually tried that later (I forgot to add the benchmarks for it). It does better than the purely recursive version, but it's still way off the pace of the generated function:

```julia
julia> @benchmark Zygote.gradient(p -> sum($model(p)[1]), $x)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  41.458 μs …   6.534 ms  ┊ GC (min … max): 0.00% … 96.79%
 Time  (median):     44.791 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   49.408 μs ± 140.782 μs  ┊ GC (mean ± σ):  6.80% ±  2.37%

      ▃▅█▃
  ▁▂▂▄▅▇████▆▄▃▃▃▃▃▃▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
  41.5 μs         Histogram: frequency by time         64.3 μs <

 Memory estimate: 49.86 KiB, allocs estimate: 354.
```

The second version performs about the same, I just checked.

---
```diff
+@functor PairwiseFusion
+
+Base.getindex(m::PairwiseFusion, i) = m.layers[i]
+Base.getindex(m::PairwiseFusion, i::AbstractVector) = PairwiseFusion(m.connection, m.layers[i])
+Base.getindex(m::PairwiseFusion{<:Any, <:NamedTuple}, i::AbstractVector) =
+  PairwiseFusion(m.connection, NamedTuple{keys(m)[i]}(Tuple(m.layers)[i]))
+
+Base.keys(m::PairwiseFusion) = keys(getfield(m, :layers))
```
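A sketch of how these indexing methods behave with named sub-layers (the names `enc` and `dec` are invented for illustration):

```julia
using Flux

m = PairwiseFusion(+; enc = Dense(1, 30), dec = Dense(30, 10))

keys(m)           # (:enc, :dec)
m[1] === m[:enc]  # true: positional and named indexing reach the same sub-layer
m[[1]]            # a new PairwiseFusion holding only the :enc layer, name preserved
```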
""" | ||
Embedding(in => out; init=randn) | ||
|
||
|
@@ -556,7 +642,7 @@ end | |
@functor Embedding | ||
|
||
Embedding((in, out)::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(init(out, in)) | ||
|
||
(m::Embedding)(x::Integer) = m.weight[:, x] | ||
(m::Embedding)(x::AbstractVector) = NNlib.gather(m.weight, x) | ||
(m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...) | ||
|
@@ -565,7 +651,7 @@ function (m::Embedding)(x::Union{OneHotVector{T,L}, OneHotMatrix{T,L}}) where {T | |
size(m.weight, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(m.weight, 2)) != $L")) | ||
return m(onecold(x)) | ||
end | ||
|
||
function Base.show(io::IO, m::Embedding) | ||
print(io, "Embedding(", size(m.weight, 2), " => ", size(m.weight, 1), ")") | ||
end |
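A small usage sketch of the `Embedding` methods above (the sizes are arbitrary):

```julia
using Flux

e = Embedding(10 => 4)   # 4×10 weight matrix, initialised with randn32

e(3)            # column 3 of the weight: a 4-element vector
e([1, 2, 3])    # a 4×3 matrix, gathered column by column via NNlib.gather
e([1 2; 3 4])   # a 4×2×2 array: e(vec(x)) reshaped to (:, size(x)...)
```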
---

This will get a test, I presume?

---

Added on the other PR.