
PairwiseFusion layer, take 2 — FluxML/Flux.jl#1983
Merged (14 commits, Jun 6, 2022)
3 changes: 3 additions & 0 deletions NEWS.md
@@ -1,5 +1,8 @@
# Flux Release Notes

## v0.13.4
* Added [`PairwiseFusion` layer](https://github.com/FluxML/Flux.jl/pull/1983)

## v0.13
* After a deprecations cycle, the datasets in `Flux.Data` have
been removed in favour of MLDatasets.jl.
2 changes: 1 addition & 1 deletion src/Flux.jl
@@ -16,7 +16,7 @@ export gradient
# Pirate error to catch a common mistake. (Internal function `base` because overloading `update!` is more likely to give ambiguities.)
Optimisers.base(dx::Zygote.Grads) = error("Optimisers.jl cannot be used with Zygote.jl's implicit gradients, `Params` & `Grads`")

-export Chain, Dense, Maxout, SkipConnection, Parallel,
+export Chain, Dense, Maxout, SkipConnection, Parallel, PairwiseFusion,
RNN, LSTM, GRU, GRUv3,
SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,
112 changes: 106 additions & 6 deletions src/layers/basic.jl
@@ -67,7 +67,7 @@ end

Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i])
Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) =
-    Chain(NamedTuple{Base.keys(c)[i]}(Tuple(c.layers)[i]))
+    Chain(NamedTuple{keys(c)[i]}(Tuple(c.layers)[i]))
function Base.show(io::IO, c::Chain)
print(io, "Chain(")
_show_layers(io, c.layers)
@@ -487,7 +487,7 @@ end
Parallel(connection, layers...) = Parallel(connection, layers)
function Parallel(connection; kw...)
layers = NamedTuple(kw)
-  if :layers in Base.keys(layers) || :connection in Base.keys(layers)
+  if :layers in keys(layers) || :connection in keys(layers)
throw(ArgumentError("a Parallel layer cannot have a named sub-layer called `connection` or `layers`"))
end
isempty(layers) && return Parallel(connection, ())
@@ -510,16 +510,116 @@ end
Base.getindex(m::Parallel, i) = m.layers[i]
Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i])
Base.getindex(m::Parallel{<:Any, <:NamedTuple}, i::AbstractVector) =
-  Parallel(m.connection, NamedTuple{Base.keys(m)[i]}(Tuple(m.layers)[i]))
+  Parallel(m.connection, NamedTuple{keys(m)[i]}(Tuple(m.layers)[i]))

-Base.keys(m::Parallel) = Base.keys(getfield(m, :layers))
+Base.keys(m::Parallel) = keys(getfield(m, :layers))

function Base.show(io::IO, m::Parallel)
print(io, "Parallel(", m.connection, ", ")
_show_layers(io, m.layers)
print(io, ")")
end

"""
PairwiseFusion(connection, layers...)

```
x1 --> layer1 --> y1
|
|--> connection --> layer2 --> y2
| |
x2 |--> connection --> layer3 --> y3
| |
x3 |--> connection --> y4
|
x4
```

## Arguments

- `connection`: A function that takes two inputs and combines them
- `layers`: The layers whose outputs are combined

## Inputs

This layer behaves differently based on input type:

1. Input `x` is a tuple of length `N` (so `layers` must also have length `N`). Each sub-layer receives the previous output fused with the next input; the computation is equivalent to:

```julia
y = x[1]
for i in 1:N
    y = layers[i](y)                        # this is output `yᵢ`
    i < N && (y = connection(y, x[i + 1]))
end
```

2. Any other kind of input `x` is reused at every fusion step:

```julia
y = x
for i in 1:N
    y = layers[i](y)                        # output `yᵢ`
    i < N && (y = connection(y, x))
end
```

## Returns

A tuple of length `N` containing the output of each sub-layer (`y1, y2, ..., yN` in the diagram above).
"""
[Review comment, Member] Would be good to add a concrete example

[Member Author] There's one in the tests, but I was holding back because I wanted to take time later to come up with a better example

[mcabbott, Jun 6, 2022] How about:

julia> PairwiseFusion(vcat, x->x+1, x->x.+2, x->x.^3)(2)
(3, [5, 4], [125, 64, 8])

julia> PairwiseFusion(vcat, x->x+1, x->x.+2, x->x.^3)((2, 10, 20))
(3, [5, 12], [125, 1728, 8000])

struct PairwiseFusion{F, T <: NamedTuple}
connection::F
layers::T
end

function PairwiseFusion(connection, layers...)
names = ntuple(i -> Symbol("layer_$i"), length(layers))
return PairwiseFusion(connection, NamedTuple{names}(layers))
end
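The varargs constructor above stores positional layers in a `NamedTuple` with auto-generated keys. A quick plain-Julia illustration of that naming scheme (no Flux needed; `sum`, `prod`, and `first` stand in for real layers):

```julia
# Mirror of the constructor's naming logic: positional layer i gets key :layer_i.
layers = (sum, prod, first)
names = ntuple(i -> Symbol("layer_$i"), length(layers))
# names == (:layer_1, :layer_2, :layer_3)
nt = NamedTuple{names}(layers)
nt.layer_2([2, 3])  # prod([2, 3]) == 6
```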

function _pairwise_check(lx, N, T)
if T <: Tuple && lx != N
throw(ArgumentError("PairwiseFusion with $N sub-layers can take one input or $N inputs, but got $lx inputs"))
end
end
ChainRulesCore.@non_differentiable _pairwise_check(lx, N, T)

function (m::PairwiseFusion)(x::T) where {T}
[Review comment, Member] Should this, like Parallel, allow m(x1, x2, x3) == m((x1, x2, x3))?

I also wonder if the one-x case should be x::AbstractArray{<:Number}, or something, so that we don't find out someone is relying on some unintended behaviour, e.g. how a NamedTuple is handled. Although Parallel does not make such a restriction.

[Member] The first suggestion seems okay. The second would make the layer less usable if the sub-layers are custom layers that accept something other than the type restriction that we provide. Presumably, in most cases, the sub-layers should be appropriately restricted to throw an error?

[Member Author] I've allowed the first. Not sure about the second...

[Member] I don't see the method allowing this:

julia> PairwiseFusion(vcat, x->x.+1, x->x.+2, x->x.^3)(2, 10, 20)
ERROR: MethodError: no method matching (::PairwiseFusion{typeof(vcat), Tuple{var"#130#133", var"#131#134", var"#132#135"}})(::Int64, ::Int64, ::Int64)
Closest candidates are:
  (::PairwiseFusion)(::T) where T at REPL[11]:1

Agree that the second is less obvious. In particular it rules out many easy readme examples.

[Member Author] Whoops, yeah, I think I'd missed pushing that - should be fixed now

lx = length(x)
N = length(m.layers)
_pairwise_check(lx, N, T)
applypairwisefusion(m.layers, m.connection, x)
end

@generated function applypairwisefusion(layers::NamedTuple{names}, connection, x::T) where {names, T}
N = length(names)
y_symbols = [gensym() for _ in 1:(N + 1)]
getinput(i) = T <: Tuple ? :(x[$i]) : :x
calls = [:($(y_symbols[N + 1]) = $(getinput(1)))]
for i in 1:N - 1
push!(calls, quote
$(y_symbols[i]) = layers[$i]($(y_symbols[N + 1]))
$(y_symbols[N + 1]) = connection($(y_symbols[i]), $(getinput(i + 1)))
end)
end
push!(calls, :($(y_symbols[N]) = layers[$N]($(y_symbols[N + 1]))))
push!(calls, :(return tuple($(Tuple(y_symbols[1:N])...))))
return Expr(:block, calls...)
end
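The `@generated` function above unrolls the fusion loop at compile time. For `N = 3` with a tuple input, the generated body is roughly equivalent to the hand-written sketch below (gensym'd names replaced by `y1`/`y2`/`y3`/`acc`, and the function name invented, for readability):

```julia
# Hand-written equivalent of the code `applypairwisefusion` generates for
# three layers and a tuple input; the last layer gets no trailing connection.
function applypairwisefusion3(layers, connection, x::Tuple)
    acc = x[1]
    y1 = layers[1](acc); acc = connection(y1, x[2])
    y2 = layers[2](acc); acc = connection(y2, x[3])
    y3 = layers[3](acc)
    return (y1, y2, y3)
end

applypairwisefusion3((x -> x .+ 1, x -> x .+ 2, x -> x .^ 3), vcat, (2, 10, 20))
# (3, [5, 12], [125, 1728, 8000])
```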

@functor PairwiseFusion

Base.getindex(m::PairwiseFusion, i) = m.layers[i]
Base.getindex(m::PairwiseFusion{<:Any, <:NamedTuple}, i::AbstractVector) =
PairwiseFusion(m.connection, NamedTuple{keys(m)[i]}(Tuple(m.layers)[i]))

Base.keys(m::PairwiseFusion) = keys(getfield(m, :layers))

function Base.show(io::IO, m::PairwiseFusion)
print(io, "PairwiseFusion(", m.connection, ", ")
_show_layers(io, m.layers)
print(io, ")")
end

"""
Embedding(in => out; init=randn)

@@ -556,7 +656,7 @@ end
@functor Embedding

Embedding((in, out)::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(init(out, in))

(m::Embedding)(x::Integer) = m.weight[:, x]
(m::Embedding)(x::AbstractVector) = NNlib.gather(m.weight, x)
(m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)
Expand All @@ -565,7 +665,7 @@ function (m::Embedding)(x::Union{OneHotVector{T,L}, OneHotMatrix{T,L}}) where {T
size(m.weight, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(m.weight, 2)) != $L"))
return m(onecold(x))
end

function Base.show(io::IO, m::Embedding)
print(io, "Embedding(", size(m.weight, 2), " => ", size(m.weight, 1), ")")
end
3 changes: 2 additions & 1 deletion src/layers/show.jl
@@ -1,6 +1,6 @@

for T in [
-  :Chain, :Parallel, :SkipConnection, :Recur, :Maxout # container types
+  :Chain, :Parallel, :SkipConnection, :Recur, :Maxout, :PairwiseFusion # container types
]
@eval function Base.show(io::IO, m::MIME"text/plain", x::$T)
if get(io, :typeinfo, nothing) === nothing # e.g. top level in REPL
@@ -53,6 +53,7 @@ _show_children(x) = trainable(x) # except for layers which hide their Tuple:
_show_children(c::Chain) = c.layers
_show_children(m::Maxout) = m.layers
_show_children(p::Parallel) = (p.connection, p.layers...)
_show_children(f::PairwiseFusion) = (f.connection, f.layers...)

for T in [
:Conv, :ConvTranspose, :CrossCor, :Dense, :Scale, :Bilinear, :Embedding,
16 changes: 16 additions & 0 deletions test/layers/basic.jl
@@ -350,3 +350,19 @@ end
@test Flux.destructure(m1)[2](z1)[1].weight ≈ Flux.destructure(m1v)[2](z1)[1].weight
# Note that Flux.destructure(m1v)[2](z) has a Chain{Tuple}, as does m1v[1:2]
end

@testset "PairwiseFusion" begin
x = (rand(1, 10), rand(30, 10))
layer = PairwiseFusion(+, Dense(1, 30), Dense(30, 10))
y = layer(x)
@test length(y) == 2
@test size(y[1]) == (30, 10)
@test size(y[2]) == (10, 10)

x = rand(1, 10)
layer = PairwiseFusion(.+, Dense(1, 10), Dense(10, 1))
y = layer(x)
@test length(y) == 2
@test size(y[1]) == (10, 10)
@test size(y[2]) == (1, 10)
end