Merge pull request #1983 from theabhirath/pairwise-fusion-2
`PairwiseFusion` layer, take 2
ToucheSir authored Jun 6, 2022
2 parents f86b356 + d0f0a29 commit 0b01b77
# Flux Release Notes

## v0.13.4
* Added [`PairwiseFusion` layer](https://github.com/FluxML/Flux.jl/pull/1983)

## v0.13
* After a deprecation cycle, the datasets in `Flux.Data` have
been removed in favour of MLDatasets.jl.
# Pirate error to catch a common mistake. (Internal function `base` because overloading `update!` is more likely to give ambiguities.)
Optimisers.base(dx::Zygote.Grads) = error("Optimisers.jl cannot be used with Zygote.jl's implicit gradients, `Params` & `Grads`")

export Chain, Dense, Maxout, SkipConnection, Parallel,
export Chain, Dense, Maxout, SkipConnection, Parallel, PairwiseFusion,
SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,
Expand Up @@ -38,7 +38,7 @@ end

Chain(xs...) = Chain(xs)
function Chain(; kw...)
:layers in Base.keys(kw) && throw(ArgumentError("a Chain cannot have a named layer called `layers`"))
:layers in keys(kw) && throw(ArgumentError("a Chain cannot have a named layer called `layers`"))
isempty(kw) && return Chain(())
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i])
Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) =
function, c::Chain)
print(io, "Chain(")
_show_layers(io, c.layers)
Parallel(connection, layers...) = Parallel(connection, layers)
function Parallel(connection; kw...)
layers = NamedTuple(kw)
if :layers in Base.keys(layers) || :connection in Base.keys(layers)
if :layers in keys(layers) || :connection in keys(layers)
throw(ArgumentError("a Parallel layer cannot have a named sub-layer called `connection` or `layers`"))
isempty(layers) && return Parallel(connection, ())
(m::Parallel)(x) = m.connection(map(f -> f(x), Tuple(m.layers))...)
(m::Parallel)(xs::Tuple) = m(xs...)
function (m::Parallel)(xs...)
nl = length(m.layers)
nx = length(xs)
function _parallel_check(layers, xs)
nl = length(layers)
nx = length(xs)
if (nl != nx)
throw(ArgumentError("Parallel with $nl sub-layers can take one input or $nl inputs, but got $nx inputs"))
function (m::Parallel)(xs...)
_parallel_check(m.layers, xs)
m.connection(map(|>, xs, Tuple(m.layers))...)

Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i])
Base.getindex(m::Parallel{<:Any, <:NamedTuple}, i::AbstractVector) =
Parallel(m.connection, NamedTuple{Base.keys(m)[i]}(Tuple(m.layers)[i]))
Base.keys(m::Parallel) = Base.keys(getfield(m, :layers))
function, m::Parallel)
print(io, "Parallel(", m.connection, ", ")
_show_layers(io, m.layers)
PairwiseFusion(connection, layers...)
## Arguments
- `connection`: A function taking 2 inputs and combining them into a single output
- `layers`: The layers whose outputs are combined
## Inputs
This layer behaves differently based on input type:
1. If input `x` is a tuple of length N (or the input is `xs` with N `x`'s), matching the number of `layers`,
then each layer receives a new input `x[i]` combined with the previous output `y[i-1]` using `connection`.
Thus `(y1, y2, y3) = PairwiseFusion(connection, layer1, layer2, layer3)((x1, x2, x3))`
may be drawn as:
```
x1 → layer1 → y1 ↘
                  connection → layer2 → y2 ↘
              x2 ↗                          connection → layer3 → y3
                                        x3 ↗
```
... or written as:
```julia
y1 = layer1(x1)
y2 = layer2(connection(y1, x2))
y3 = layer3(connection(y2, x3))
```
2. With just one input, each layer receives the same `x` combined with the previous output.
Thus `y = PairwiseFusion(connection, layers...)(x)` obeys:
```julia
y[1] == layers[1](x)
for i in 2:length(layers)
    y[i] == layers[i](connection(y[i-1], x))
end
```
## Returns
A tuple of length N with the output of each fusion ((`y1`, `y2`, ..., `yN`) in the example above).
PairwiseFusion(connection, layers...) = PairwiseFusion(connection, layers)
function PairwiseFusion(connection; kw...)
layers = NamedTuple(kw)
if :layers in keys(layers) || :connection in keys(layers)
throw(ArgumentError("a PairwiseFusion layer cannot have a named sub-layer called `connection` or `layers`"))
isempty(layers) && return PairwiseFusion(connection, ())
# Validate the number of inputs handed to a `PairwiseFusion`.
#
# Arguments:
# - `x`: the input as received by the layer.
# - `layers`: the container of sub-layers; only its `length` is used.
# - `T`: the type of `x`, supplied by the caller.
#
# Only a tuple input is checked: it must hold exactly one element per sub-layer,
# otherwise an `ArgumentError` is thrown. Any other input is accepted as-is, and
# `length` is deliberately not called on it, so a single input whose type has no
# `length` method does not fail here with a `MethodError`.
# Returns `nothing`.
function _pairwise_check(x, layers, T)
    # A non-tuple input is shared by every layer, so there is nothing to count.
    T <: Tuple || return nothing
    lx = length(x)
    N = length(layers)
    if lx != N
        throw(ArgumentError("PairwiseFusion with $N sub-layers can take one input or $N inputs, but got $lx inputs"))
    end
    return nothing
end
# Forward pass of `PairwiseFusion` for a single positional argument `x`.
# The type of `x` is captured as `T` and passed to `_pairwise_check` together
# with the sub-layers; the layers are then applied by `applypairwisefusion`,
# whose result is returned.
function (m::PairwiseFusion)(x::T) where {T}
    _pairwise_check(x, m.layers, T)
    return applypairwisefusion(m.layers, m.connection, x)
end
# Build, at compile time, the unrolled forward pass of a `PairwiseFusion`.
#
# Arguments:
# - `layers`: tuple of `N` callables, applied in order.
# - `connection`: two-argument callable, invoked as `connection(y_prev, input)`.
# - `x`: the input. If its type `T` is a tuple type, layer `i` reads `x[i]`;
#   otherwise every layer reads the same `x`.
#
# The generated code computes
#     y[1] = layers[1](input(1))
#     y[i] = layers[i](connection(y[i-1], input(i)))   for i in 2:N
# and returns `(y[1], ..., y[N])`; with no layers it returns `()`.
#
# `Tuple{Vararg{Any,N}}` matches any `N`-tuple and binds `N` for the unrolling.
@generated function applypairwisefusion(layers::Tuple{Vararg{Any,N}}, connection, x::T) where {N, T}
    # Nothing to apply: without this guard `y_symbols[N]` below would index at 0.
    N == 0 && return :(return ())
    # y_symbols[1:N] name the per-layer outputs; y_symbols[N + 1] is a scratch
    # variable holding whatever is fed to the next layer.
    y_symbols = [gensym() for _ in 1:(N + 1)]
    # Expression reading the i-th input, decided from the type of `x`.
    getinput(i) = T <: Tuple ? :(x[$i]) : :x
    calls = [:($(y_symbols[N + 1]) = $(getinput(1)))]
    for i in 1:(N - 1)
        push!(calls, quote
            $(y_symbols[i]) = layers[$i]($(y_symbols[N + 1]))
            $(y_symbols[N + 1]) = connection($(y_symbols[i]), $(getinput(i + 1)))
        end)
    end
    # The last layer has no successor, so its output is not fused any further.
    push!(calls, :($(y_symbols[N]) = layers[$N]($(y_symbols[N + 1]))))
    push!(calls, :(return tuple($(Tuple(y_symbols[1:N])...))))
    return Expr(:block, calls...)
end
# Generic indexing forwards to the stored sub-layers and returns `m.layers[i]`.
function Base.getindex(m::PairwiseFusion, i)
    return m.layers[i]
end

# Indexing with a vector builds a new `PairwiseFusion` from the selected
# sub-layers, keeping the same `connection`.
function Base.getindex(m::PairwiseFusion, i::AbstractVector)
    conn = m.connection
    chosen = m.layers[i]
    return PairwiseFusion(conn, chosen)
end

# When the sub-layers are held in a `NamedTuple`, the same index vector selects
# both the names and the values, and the result is rebuilt as a `NamedTuple`.
function Base.getindex(m::PairwiseFusion{<:Any, <:NamedTuple}, i::AbstractVector)
    conn = m.connection
    picked_names = keys(m)[i]
    picked_layers = Tuple(m.layers)[i]
    return PairwiseFusion(conn, NamedTuple{picked_names}(picked_layers))
end

# The keys of a `PairwiseFusion` are the keys of its `layers` field, which is
# read with `getfield`.
function Base.keys(m::PairwiseFusion)
    return keys(getfield(m, :layers))
end

function, m::PairwiseFusion)
print(io, "PairwiseFusion(", m.connection, ", ")
_show_layers(io, m.layers)
Embedding(in => out; init=randn)
Expand Down Expand Up @@ -556,7 +666,7 @@ end
@functor Embedding

# Call methods of `Embedding`, dispatched on the type of the index `x`.
# An integer selects column `x` of `m.weight`.
(m::Embedding)(x::Integer) = m.weight[:, x]
# A vector of indices is handed to `NNlib.gather` together with `m.weight`.
(m::Embedding)(x::AbstractVector) = NNlib.gather(m.weight, x)
# Any other array is flattened with `vec`, passed through the layer again, and
# the result is reshaped to `(:, size(x)...)`.
(m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)
Expand All @@ -565,7 +675,7 @@ function (m::Embedding)(x::Union{OneHotVector{T,L}, OneHotMatrix{T,L}}) where {T
size(m.weight, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(m.weight, 2)) != $L"))
function, m::Embedding)
print(io, "Embedding(", size(m.weight, 2), " => ", size(m.weight, 1), ")")
for T in [
:Chain, :Parallel, :SkipConnection, :Recur, :Maxout # container types
:Chain, :Parallel, :SkipConnection, :Recur, :Maxout, :PairwiseFusion # container types
@eval function, m::MIME"text/plain", x::$T)
if get(io, :typeinfo, nothing) === nothing # e.g. top level in REPL
Expand All @@ -25,7 +25,7 @@ function _big_show(io::IO, obj, indent::Int=0, name=nothing)
for k in Base.keys(obj)
_big_show(io, obj[k], indent+2, k)
elseif obj isa Parallel{<:Any, <:NamedTuple}
elseif obj isa Parallel{<:Any, <:NamedTuple} || obj isa PairwiseFusion{<:Any, <:NamedTuple}
_big_show(io, obj.connection, indent+2)
for k in Base.keys(obj)
_big_show(io, obj[k], indent+2, k)
_show_children(c::Chain) = c.layers
_show_children(m::Maxout) = m.layers
_show_children(p::Parallel) = (p.connection, p.layers...)
for T in [
:Conv, :ConvTranspose, :CrossCor, :Dense, :Scale, :Bilinear, :Embedding,
Expand Up @@ -350,3 +350,22 @@ end
# Both results, rebuilt from the same `z1`, must agree on the `weight` of their first element.
@test Flux.destructure(m1)[2](z1)[1].weight ≈ Flux.destructure(m1v)[2](z1)[1].weight
# Tests for `PairwiseFusion`: tuple input, a single shared input, and exact values.
@testset "PairwiseFusion" begin
# Tuple input: one array per sub-layer.
x = (rand(1, 10), rand(30, 10))
layer = PairwiseFusion(+, Dense(1, 30), Dense(30, 10))
y = layer(x)
# One output per sub-layer; only the first output's size is checked here.
@test length(y) == 2
@test size(y[1]) == (30, 10)
# Single array input, with broadcasting addition (`.+`) as the connection.
x = rand(1, 10)
layer = PairwiseFusion(.+, Dense(1, 10), Dense(10, 1))
y = layer(x)
@test length(y) == 2
@test size(y[1]) == (10, 10)
# Exact values with `vcat` as the connection: every layer after the first
# receives `vcat(previous output, input)`, e.g. `[5, 12] == vcat(3, 10) .+ 2`.
@test PairwiseFusion(vcat, x->x.+1, x->x.+2, x->x.^3)(2, 10, 20) == (3, [5, 12], [125, 1728, 8000])
# Same layers with one shared input: `[10, 9] == vcat(8, 7) .+ 2`.
@test PairwiseFusion(vcat, x->x.+1, x->x.+2, x->x.^3)(7) == (8, [10, 9], [1000, 729, 343])

