From deac30346ea860a9b3d365aeccb71fe938ce1f17 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 5 Feb 2022 11:48:10 -0500 Subject: [PATCH] fix tests for Flux.modules --- src/utils.jl | 7 ++++++- test/utils.jl | 23 ++++++++++++++--------- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 09eadcac61..035798b5c0 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -775,15 +775,20 @@ Chain( # plus 2 non-trainable, 128 parameters, summarysize 200.312 KiB. julia> Flux.modules(m2) -5-element Vector{Any}: +7-element Vector{Any}: Chain(Chain(Dense(784, 64), BatchNorm(64, relu)), Dense(64, 10)) # 51_018 parameters, plus 128 non-trainable + (Chain(Dense(784, 64), BatchNorm(64, relu)), Dense(64, 10)) Chain(Dense(784, 64), BatchNorm(64, relu)) # 50_368 parameters, plus 128 non-trainable + (Dense(784, 64), BatchNorm(64, relu)) Dense(784, 64) # 50_240 parameters BatchNorm(64, relu) # 128 parameters, plus 128 non-trainable Dense(64, 10) # 650 parameters julia> L2(m) = sum(sum(abs2, l.weight) for l in Flux.modules(m) if l isa Dense) L2 (generic function with 1 method) + +julia> L2(m2) isa Float32 +true ``` """ modules(m) = [x for x in Functors.fcollect(m) if !isleaflike(x)] diff --git a/test/utils.jl b/test/utils.jl index 6b487e7854..1681a0df28 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -446,21 +446,26 @@ end m5 = Chain(m4, m2) modules = Flux.modules(m5) # Depth-first descent - @test length(modules) == 5 + @test length(modules) == 6 @test modules[1] === m5 - @test modules[2] === m4 - @test modules[3] === m1 - @test modules[4] === m2 - @test modules[5] === m3 + @test modules[3] === m4 + @test modules[4] === m1 + @test modules[5] === m2 + @test modules[6] === m3 - modules = Flux.modules(Chain(Dense(2,3), BatchNorm(3), LSTM(3,4))) - @test length(modules) == 5 + mod_par = Flux.modules(Parallel(Flux.Bilinear(2,2,2,cbrt), Dense(2,2,abs), Dense(2,2,abs2))) + @test length(mod_par) == 5 - modules = Flux.modules(Chain(SkipConnection( + mod_rnn = Flux.modules(Chain(Dense(2,3), BatchNorm(3), LSTM(3,4))) + @test length(mod_rnn) == 6 + @test mod_rnn[end] isa Flux.LSTMCell + + mod_skip = Flux.modules(Chain(SkipConnection( Conv((2,3), 4=>5; pad=6, stride=7), +), LayerNorm(8))) - @test length(modules) == 5 + @test length(mod_skip) == 6 + @test mod_skip[end] isa Flux.Diagonal end @testset "Patience triggers" begin