From 356d09d23787ea74012bf5147ccbe14036985b3b Mon Sep 17 00:00:00 2001 From: cstjean Date: Mon, 7 Nov 2022 05:51:51 -0500 Subject: [PATCH 1/2] Support Vararg Chain (Chain of Parallel) Closes #2100 As mentioned in https://github.com/FluxML/Flux.jl/issues/2100#issuecomment-1305399770, this will break any code using `Chain()` as the identity function. --- src/layers/basic.jl | 25 ++++++++++++++----------- test/layers/basic.jl | 7 ++++++- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 2a3bc9131c..dee92648e6 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -48,18 +48,21 @@ end @functor Chain -(c::Chain)(x) = _applychain(c.layers, x) +(c::Chain)(inputs...) = _applychain(c.layers, inputs...) -@generated function _applychain(layers::Tuple{Vararg{<:Any,N}}, x) where {N} - symbols = vcat(:x, [gensym() for _ in 1:N]) - calls = [:($(symbols[i+1]) = layers[$i]($(symbols[i]))) for i in 1:N] - Expr(:block, calls...) +@generated function _applychain(layers::Tuple{Vararg{<:Any,N}}, inputs...) where {N} + symbols = [gensym() for _ in 1:N] + calls = [:($(symbols[i]) = layers[$i]($(symbols[i-1]))) for i in 2:N] + Expr(:block, + :($(symbols[1]) = layers[1](inputs...)), + calls...) end -_applychain(layers::NamedTuple, x) = _applychain(Tuple(layers), x) +_applychain(layers::NamedTuple, inputs...) = _applychain(Tuple(layers), inputs...) -function _applychain(layers::AbstractVector, x) # type-unstable path, helps compile times - for f in layers +function _applychain(layers::AbstractVector, inputs...) # type-unstable path, helps compile times + x = layers[1](inputs...) + for f in @view(layers[2:end]) x = f(x) end x @@ -99,11 +102,11 @@ julia> activations(c, 1) (2, 4, 64) ``` """ -activations(c::Chain, input) = _extraChain(Tuple(c.layers), input) +activations(c::Chain, inputs...) = _extraChain(Tuple(c.layers), inputs...) # Calculates the forward results of each layer provided in a `Tuple` with `x` as model input. 
-function _extraChain(fs::Tuple, x) - res = first(fs)(x) +function _extraChain(fs::Tuple, inputs...) + res = first(fs)(inputs...) return (res, _extraChain(Base.tail(fs), res)...) end _extraChain(::Tuple{}, x) = () diff --git a/test/layers/basic.jl b/test/layers/basic.jl index 1f9d30dec5..5bc522766f 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -1,5 +1,5 @@ using Test, Random -import Flux: activations +import Flux: activations, OneHotArray, OneHotMatrix, OneHotVector, onehotbatch, params, Zygote @testset "basic" begin @testset "helpers" begin @@ -216,6 +216,11 @@ import Flux: activations @test size(Parallel(hcat, one = Dense(10, 10), two = identity)(input)) == (10, 4) end + @testset "parallel chain" begin + inputs = (randn(2, 10), randn(3, 10)) + @test size(Chain(Parallel(vcat, Dense(2, 5), identity), Dense(8, 4))(inputs...)) == (4, 10) + end + @testset "vararg input" begin inputs = randn(10), randn(5), randn(4) @test size(Parallel(+, Dense(10, 2), Dense(5, 2), Dense(4, 2))(inputs)) == (2,) From 0b19bcc647c44dcef6b7e1d4c4c8dc819860b828 Mon Sep 17 00:00:00 2001 From: cstjean Date: Mon, 7 Nov 2022 06:08:23 -0500 Subject: [PATCH 2/2] Documentation for Chain(Parallel(...)) --- src/layers/basic.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index dee92648e6..bd1ab34ea3 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -493,6 +493,12 @@ julia> model2[:α](rand(10)) |> size julia> model2[:β] == model2[2] true + +julia> model3 = Chain(Parallel(+, Dense(5 => 4), Embedding(15=>4)), + Dense(4 => 17)); + +julia> model3(randn(5, 10), rand(1:15, 10)) |> size +(17, 10) ``` """ struct Parallel{F, T<:Union{Tuple, NamedTuple}}