From e0ee619d8ed8c51dbb54131dff955f427aa5dfaa Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Wed, 2 Feb 2022 23:21:52 -0500
Subject: [PATCH 1/7] simple functor Chain

---
 src/layers/basic.jl  | 23 ++++++++++++-----------
 test/layers/basic.jl |  3 +++
 2 files changed, 15 insertions(+), 11 deletions(-)

diff --git a/src/layers/basic.jl b/src/layers/basic.jl
index e40457ef53..69f6f8c226 100644
--- a/src/layers/basic.jl
+++ b/src/layers/basic.jl
@@ -28,29 +28,30 @@ julia> m2(x) == (m2[:dec] ∘ m2[:enc])(x)
 true
 ```
 """
-struct Chain{T}
+struct Chain{T<:Union{Tuple, NamedTuple}}
   layers::T
-  Chain(xs...) = new{typeof(xs)}(xs)
-  function Chain(; kw...)
-    :layers in Base.keys(kw) && throw(ArgumentError("a Chain cannot have a named layer called `layers`"))
-    isempty(kw) && return new{Tuple{}}(())
-    new{typeof(values(kw))}(values(kw))
-  end
+end
+
+Chain(xs...) = Chain(xs)
+function Chain(; kw...)
+  :layers in Base.keys(kw) && throw(ArgumentError("a Chain cannot have a named layer called `layers`"))
+  isempty(kw) && return Chain(())
+  Chain(values(kw))
 end
 
 @forward Chain.layers Base.getindex, Base.length, Base.first, Base.last,
   Base.iterate, Base.lastindex, Base.keys
 
-functor(::Type{<:Chain}, c) = c.layers, ls -> Chain(ls...)
+@functor Chain
 
 applychain(::Tuple{}, x) = x
 applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x))
 
 (c::Chain)(x) = applychain(Tuple(c.layers), x)
 
-Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
-Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) =
-  Chain(; NamedTuple{Base.keys(c)[i]}(Tuple(c.layers)[i])...)
+Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i])
+Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) =
+  Chain(NamedTuple{Base.keys(c)[i]}(Tuple(c.layers)[i]))
 
 function Base.show(io::IO, c::Chain)
   print(io, "Chain(")
diff --git a/test/layers/basic.jl b/test/layers/basic.jl
index b362f55f16..d2e213b849 100644
--- a/test/layers/basic.jl
+++ b/test/layers/basic.jl
@@ -25,6 +25,9 @@ import Flux: activations
     @test m[:first] == m[1]
     @test m[1:2] == m
 
+    @test m == m
+    @test m == fmap(identity, m)  # does not forget names
+
     @test_throws ArgumentError Chain(layers = Dense(10, 10), two = identity)  # reserved name
   end
 

From 527148daea825342dff0ff7de7e09f07f9d7604d Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Wed, 2 Feb 2022 23:23:36 -0500
Subject: [PATCH 2/7] simplify Maxout

---
 src/layers/basic.jl | 18 ++++++------------
 1 file changed, 6 insertions(+), 12 deletions(-)

diff --git a/src/layers/basic.jl b/src/layers/basic.jl
index 69f6f8c226..4ee3c2407f 100644
--- a/src/layers/basic.jl
+++ b/src/layers/basic.jl
@@ -246,29 +246,23 @@ julia> Flux.outputsize(m3, (5, 11))
 (7, 11)
 ```
 """
-struct Maxout{FS<:Tuple}
-  over::FS
-  Maxout(layers...) = new{typeof(layers)}(layers)
-end
-
-function Maxout(f::Function, n_alts::Integer)
-  over = Tuple(f() for _ in 1:n_alts)
-  return Maxout(over...)
+struct Maxout{T<:Tuple}
+  layers::T
 end
+
+Maxout(layers...) = Maxout(layers)
+Maxout(f::Function, n_alts::Integer) = Maxout((f() for _ in 1:n_alts)...)
 
 @functor Maxout
 
 function (mo::Maxout)(input::AbstractArray)
   # Perhaps surprisingly, pairwise max broadcast is often faster,
   # even with Zygote. See #698 and #1794
-  mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over)
+  mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.layers)
 end
 
-trainable(mo::Maxout) = mo.over
-
 function Base.show(io::IO, mo::Maxout)
   print(io, "Maxout(")
-  _show_layers(io, mo.over)
+  _show_layers(io, mo.layers)
   print(io, ")")
 end
 

From ab0e8c2904cc70e9624b8fc8265a903d46f1f2b5 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Wed, 2 Feb 2022 23:24:35 -0500
Subject: [PATCH 3/7] fix show as a result

---
 src/layers/show.jl | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/layers/show.jl b/src/layers/show.jl
index 791d2511ca..85faec3c59 100644
--- a/src/layers/show.jl
+++ b/src/layers/show.jl
@@ -14,7 +14,7 @@ for T in [
 end
 
 function _big_show(io::IO, obj, indent::Int=0, name=nothing)
-  children = trainable(obj)
+  children = _show_children(obj)
   if all(_show_leaflike, children)
     _layer_show(io, obj, indent, name)
   else
@@ -48,6 +48,11 @@ _show_leaflike(::Tuple{Vararg{<:Number}}) = true  # e.g. stride of Conv
 _show_leaflike(::Tuple{Vararg{<:AbstractArray}}) = true  # e.g. parameters of LSTMcell
 _show_leaflike(::Diagonal) = true  # appears inside LayerNorm
 
+_show_children(x) = trainable(x)  # except for layers which hide their Tuple:
+_show_children(c::Chain) = c.layers
+_show_children(m::Maxout) = m.layers
+_show_children(p::Parallel) = (p.connection, p.layers...)
+
 for T in [
   :Conv, :ConvTranspose, :CrossCor, :DepthwiseConv, :Dense,
   :BatchNorm, :LayerNorm, :InstanceNorm, :GroupNorm,

From fd1a5376f3543e83ece61ce68b9b9dfd9ca25eff Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Wed, 2 Feb 2022 23:26:17 -0500
Subject: [PATCH 4/7] trainable always a NamedTuple

---
 src/layers/normalise.jl | 12 +++++-------
 src/layers/recurrent.jl |  2 +-
 2 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl
index 164fc0d782..b0dec16241 100644
--- a/src/layers/normalise.jl
+++ b/src/layers/normalise.jl
@@ -82,8 +82,7 @@ function Dropout(p; dims=:, rng = rng_from_array())
 end
 
 @functor Dropout
-
-trainable(a::Dropout) = ()
+trainable(a::Dropout) = (;)
 
 function (a::Dropout)(x)
   _isactive(a) || return x
@@ -122,8 +121,7 @@ AlphaDropout(p, active) = AlphaDropout(p, active, rng_from_array())
 AlphaDropout(p; rng = rng_from_array()) = AlphaDropout(p, nothing, rng)
 
 @functor AlphaDropout
-
-trainable(a::AlphaDropout) = ()
+trainable(a::AlphaDropout) = (;)
 
 function (a::AlphaDropout)(x::AbstractArray{T}) where T
   _isactive(a) || return x
@@ -288,7 +286,7 @@ function BatchNorm(chs::Int, λ=identity;
 end
 
 @functor BatchNorm
-trainable(bn::BatchNorm) = hasaffine(bn) ? (bn.β, bn.γ) : ()
+trainable(bn::BatchNorm) = hasaffine(bn) ? (β = bn.β, γ = bn.γ) : (;)
 
 function (BN::BatchNorm)(x)
   @assert size(x, ndims(x)-1) == BN.chs
@@ -364,7 +362,7 @@ function InstanceNorm(chs::Int, λ=identity;
 end
 
 @functor InstanceNorm
-trainable(in::InstanceNorm) = hasaffine(in) ? (in.β, in.γ) : ()
+trainable(in::InstanceNorm) = hasaffine(in) ? (β = in.β, γ = in.γ) : (;)
 
 function (l::InstanceNorm)(x)
   @assert ndims(x) > 2
@@ -426,7 +424,7 @@ mutable struct GroupNorm{F,V,N,W}
 end
 
 @functor GroupNorm
-trainable(gn::GroupNorm) = hasaffine(gn) ? (gn.β, gn.γ) : ()
+trainable(gn::GroupNorm) = hasaffine(gn) ? (β = gn.β, γ = gn.γ) : (;)
 
 function GroupNorm(chs::Int, G::Int, λ=identity;
                    initβ=zeros32, initγ=ones32,
diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index 87f77c565a..c8a34354d7 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -65,7 +65,7 @@ function (m::Recur)(x)
 end
 
 @functor Recur
-trainable(a::Recur) = (a.cell,)
+trainable(a::Recur) = (; cell = a.cell)
 
 Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
 

From a83abe8c245e1b85a3744b37f2b647454f964cf5 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Fri, 4 Feb 2022 21:56:39 -0500
Subject: [PATCH 5/7] Parallel: delete trainable, call combiner once

---
 src/layers/basic.jl  | 26 ++++++++------
 test/layers/basic.jl | 25 +++++++++++++
 test/runtests.jl     | 85 +++++++++++++++++++++++---------------------
 3 files changed, 85 insertions(+), 51 deletions(-)

diff --git a/src/layers/basic.jl b/src/layers/basic.jl
index 4ee3c2407f..230e696115 100644
--- a/src/layers/basic.jl
+++ b/src/layers/basic.jl
@@ -409,8 +409,8 @@ end
 Create a `Parallel` layer that passes an input array to each path in
 `layers`, before reducing the output with `connection`.
 
-Called with one input `x`, this is equivalent to `reduce(connection, [l(x) for l in layers])`.
-If called with multiple inputs, they are `zip`ped with the layers, thus `Parallel(+, f, g)(x, y) = f(x) + g(y)`.
+Called with one input `x`, this is equivalent to `connection([l(x) for l in layers]...)`.
+If called with multiple inputs, one is passed to each layer, thus `Parallel(+, f, g)(x, y) = f(x) + g(y)`.
 
 Like [`Chain`](@ref), its sub-layers may be given names using the keyword constructor.
 These can be accessed by indexing: `m[1] == m[:name]` is the first layer.
@@ -445,7 +445,7 @@ julia> model2[:β] == model2[2]
 true
 ```
 """
-struct Parallel{F, T}
+struct Parallel{F, T<:Union{Tuple, NamedTuple}}
   connection::F
   layers::T
 end
@@ -455,25 +455,31 @@ function Parallel(connection; kw...)
   layers = NamedTuple(kw)
   if :layers in Base.keys(layers) || :connection in Base.keys(layers)
     throw(ArgumentError("a Parallel layer cannot have a named sub-layer called `connection` or `layers`"))
-  elseif isempty(layers)
-    Parallel(connection, ())
   end
+  isempty(layers) && return Parallel(connection, ())
   Parallel(connection, layers)
 end
 
 @functor Parallel
 
-(m::Parallel)(x) = mapreduce(f -> f(x), m.connection, Tuple(m.layers))
-(m::Parallel)(xs...) = mapreduce((f, x) -> f(x), m.connection, Tuple(m.layers), xs)
+(m::Parallel)(x) = m.connection(map(f -> f(x), Tuple(m.layers))...)
 (m::Parallel)(xs::Tuple) = m(xs...)
+function (m::Parallel)(xs...)
+  nl = length(m.layers)
+  nx = length(xs)
+  if nl != nx
+    throw(ArgumentError("Parallel with $nl sub-layers can take one input or $nl inputs, but got $nx inputs"))
+  end
+  m.connection(map(|>, xs, Tuple(m.layers))...)
+end
 
 Base.getindex(m::Parallel, i) = m.layers[i]
-Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i]...)
+Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i])
+Base.getindex(m::Parallel{<:Any, <:NamedTuple}, i::AbstractVector) =
+  Parallel(m.connection, NamedTuple{Base.keys(m)[i]}(Tuple(m.layers)[i]))
 
 Base.keys(m::Parallel) = Base.keys(getfield(m, :layers))
 
-trainable(m::Parallel) = (m.connection, m.layers...)
-
 function Base.show(io::IO, m::Parallel)
   print(io, "Parallel(", m.connection, ", ")
   _show_layers(io, m.layers)
diff --git a/test/layers/basic.jl b/test/layers/basic.jl
index d2e213b849..968ddd506f 100644
--- a/test/layers/basic.jl
+++ b/test/layers/basic.jl
@@ -205,14 +205,39 @@
     inputs = randn(10), randn(5), randn(4)
     @test size(Parallel(+, Dense(10, 2), Dense(5, 2), Dense(4, 2))(inputs)) == (2,)
     @test size(Parallel(+; a = Dense(10, 2), b = Dense(5, 2), c = Dense(4, 2))(inputs)) == (2,)
+    @test_throws ArgumentError Parallel(+, sin, cos)(1,2,3)  # wrong number of inputs
+    @test Parallel(+, sin, cos)(pi/2) ≈ 1
   end
 
   @testset "named access" begin
     m = Parallel(hcat, one = Dense(10, 10), two = identity)
    @test m[1] == m[:one]
+    @test m[1:2] == m
 
     @test_throws ArgumentError Parallel(hcat, layers = Dense(10, 10), two = identity)  # reserved names
     @test_throws ArgumentError Parallel(hcat, connection = Dense(10, 10), two = identity)
+
+    @test m == fmap(identity, m)  # does not forget names
+
+    @test Parallel(vcat, x = log)(1) == [0]
+    @test Parallel(vcat, log)(1) == [0]
+  end
+
+  @testset "trivial cases" begin
+    @test Parallel(hcat) isa Parallel{typeof(hcat), Tuple{}}  # not a NamedTuple
+    @test Parallel(hcat)(1) == hcat()
+    @test Parallel(hcat, inv)(2) == hcat(1/2)  # still calls connection once.
+  end
+
+  @testset "connection is called once" begin
+    CNT = Ref(0)
+    f_cnt = (x...) -> (CNT[]+=1; +(x...))
+    Parallel(f_cnt, sin, cos, tan)(1)
+    @test CNT[] == 1
+    Parallel(f_cnt, sin, cos, tan)(1,2,3)
+    @test CNT[] == 2
+    Parallel(f_cnt, sin)(1)
+    @test CNT[] == 3
   end
 
   # Ref https://github.com/FluxML/Flux.jl/issues/1673
diff --git a/test/runtests.jl b/test/runtests.jl
index 781edb549d..a6abd609d2 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -8,55 +8,58 @@ using CUDA
 
 Random.seed!(0)
 
-@testset "Utils" begin
-  include("utils.jl")
-end
+@testset verbose=true "Flux.jl" begin
 
-@testset "Onehot" begin
-  include("onehot.jl")
-end
+  @testset "Utils" begin
+    include("utils.jl")
+  end
 
-@testset "Optimise" begin
-  include("optimise.jl")
-end
+  @testset "Onehot" begin
+    include("onehot.jl")
+  end
 
-@testset "Data" begin
-  include("data.jl")
-end
+  @testset "Optimise" begin
+    include("optimise.jl")
+  end
 
-@testset "Losses" begin
-  include("losses.jl")
-  include("ctc.jl")
-  CUDA.functional() && include("ctc-gpu.jl")
-end
+  @testset "Data" begin
+    include("data.jl")
+  end
 
-@testset "Layers" begin
-  include("layers/basic.jl")
-  include("layers/normalisation.jl")
-  include("layers/stateless.jl")
-  include("layers/recurrent.jl")
-  include("layers/conv.jl")
-  include("layers/upsample.jl")
-  include("layers/show.jl")
-end
+  @testset "Losses" begin
+    include("losses.jl")
+    include("ctc.jl")
+    CUDA.functional() && include("ctc-gpu.jl")
+  end
 
-@testset "outputsize" begin
-  using Flux: outputsize
-  include("outputsize.jl")
-end
+  @testset "Layers" begin
+    include("layers/basic.jl")
+    include("layers/normalisation.jl")
+    include("layers/stateless.jl")
+    include("layers/recurrent.jl")
+    include("layers/conv.jl")
+    include("layers/upsample.jl")
+    include("layers/show.jl")
+  end
 
-@testset "CUDA" begin
-  if CUDA.functional()
-    include("cuda/runtests.jl")
-  else
-    @warn "CUDA unavailable, not testing GPU support"
+  @testset "outputsize" begin
+    using Flux: outputsize
+    include("outputsize.jl")
+  end
+
+  @testset "CUDA" begin
+    if CUDA.functional()
+      include("cuda/runtests.jl")
+    else
+      @warn "CUDA unavailable, not testing GPU support"
+    end
   end
-end
 
-@static if VERSION == v"1.6"
-  using Documenter
-  @testset "Docs" begin
-    DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true)
-    doctest(Flux)
+  @static if VERSION == v"1.6"
+    using Documenter
+    @testset "Docs" begin
+      DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true)
+      doctest(Flux)
+    end
   end
 end

From c3ec851cc6e07f4d31b87d7f7c1d44f106a9b34d Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Fri, 4 Feb 2022 21:56:58 -0500
Subject: [PATCH 6/7] fixup

---
 src/deprecations.jl | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/deprecations.jl b/src/deprecations.jl
index d25ea978ee..dd36e17d42 100644
--- a/src/deprecations.jl
+++ b/src/deprecations.jl
@@ -36,4 +36,3 @@ zeros32(::Type, dims...) = throw(ArgumentError("Flux.zeros32 is always Float32,
 
 # v0.13 deprecations
 
-@deprecate Maxout(layers::Tuple) Maxout(layers...)
\ No newline at end of file

From deac30346ea860a9b3d365aeccb71fe938ce1f17 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sat, 5 Feb 2022 11:48:10 -0500
Subject: [PATCH 7/7] fix tests for Flux.modules

---
 src/utils.jl  |  7 ++++++-
 test/utils.jl | 23 ++++++++++++++---------
 2 files changed, 20 insertions(+), 10 deletions(-)

diff --git a/src/utils.jl b/src/utils.jl
index 09eadcac61..035798b5c0 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -775,15 +775,20 @@ Chain(
 # plus 2 non-trainable, 128 parameters, summarysize 200.312 KiB.
 
 julia> Flux.modules(m2)
-5-element Vector{Any}:
+7-element Vector{Any}:
  Chain(Chain(Dense(784, 64), BatchNorm(64, relu)), Dense(64, 10))  # 51_018 parameters, plus 128 non-trainable
+ (Chain(Dense(784, 64), BatchNorm(64, relu)), Dense(64, 10))
  Chain(Dense(784, 64), BatchNorm(64, relu))  # 50_368 parameters, plus 128 non-trainable
+ (Dense(784, 64), BatchNorm(64, relu))
  Dense(784, 64)  # 50_240 parameters
  BatchNorm(64, relu)  # 128 parameters, plus 128 non-trainable
  Dense(64, 10)  # 650 parameters
 
 julia> L2(m) = sum(sum(abs2, l.weight) for l in Flux.modules(m) if l isa Dense)
 L2 (generic function with 1 method)
+
+julia> L2(m2) isa Float32
+true
 ```
 """
 modules(m) = [x for x in Functors.fcollect(m) if !isleaflike(x)]
diff --git a/test/utils.jl b/test/utils.jl
index 6b487e7854..1681a0df28 100644
--- a/test/utils.jl
+++ b/test/utils.jl
@@ -446,21 +446,26 @@ end
     m5 = Chain(m4, m2)
     modules = Flux.modules(m5)
     # Depth-first descent
-    @test length(modules) == 5
+    @test length(modules) == 6
     @test modules[1] === m5
-    @test modules[2] === m4
-    @test modules[3] === m1
-    @test modules[4] === m2
-    @test modules[5] === m3
+    @test modules[3] === m4
+    @test modules[4] === m1
+    @test modules[5] === m2
+    @test modules[6] === m3
 
-    modules = Flux.modules(Chain(Dense(2,3), BatchNorm(3), LSTM(3,4)))
-    @test length(modules) == 5
+    mod_par = Flux.modules(Parallel(Flux.Bilinear(2,2,2,cbrt), Dense(2,2,abs), Dense(2,2,abs2)))
+    @test length(mod_par) == 5
 
-    modules = Flux.modules(Chain(SkipConnection(
+    mod_rnn = Flux.modules(Chain(Dense(2,3), BatchNorm(3), LSTM(3,4)))
+    @test length(mod_rnn) == 6
+    @test mod_rnn[end] isa Flux.LSTMCell
+
+    mod_skip = Flux.modules(Chain(SkipConnection(
                                     Conv((2,3), 4=>5; pad=6, stride=7),
                                     +),
                                   LayerNorm(8)))
-    @test length(modules) == 5
+    @test length(mod_skip) == 6
+    @test mod_skip[end] isa Flux.Diagonal
   end
 
   @testset "Patience triggers" begin
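
Reviewer note (not part of the patches above): the Julia sketch below is one way to exercise the behaviour this series establishes, assuming a Flux checkout with patches 1-7 applied. The names `c`, `p`, `calls` and the counting connection `f_cnt` are illustrative only; the counting trick mirrors the new "connection is called once" test rather than any Flux API.

    using Flux

    # Named Chain: @functor now round-trips the names, and slicing keeps them.
    c = Chain(enc = Dense(10, 5, relu), dec = Dense(5, 10))
    c[:enc] === c[1]              # named and positional access agree
    Flux.fmap(identity, c) == c   # does not forget names (cf. new Chain test)
    c[1:1]                        # Chain(enc = Dense(10, 5, relu)), still named

    # Parallel: the connection runs exactly once per forward pass, and the
    # number of inputs must be 1 or match the number of sub-layers.
    calls = Ref(0)
    f_cnt = (xs...) -> (calls[] += 1; +(xs...))   # hypothetical counting connection
    p = Parallel(f_cnt, Dense(10, 2), Dense(10, 2))
    p(randn(Float32, 10))                         # one input, sent to both layers
    p(randn(Float32, 10), randn(Float32, 10))     # one input per layer
    calls[]                                       # 2: once per call above
    # p(randn(10), randn(10), randn(10))          # would throw ArgumentError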