diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl index b0a9a51845..1046248d13 100644 --- a/test/layers/normalisation.jl +++ b/test/layers/normalisation.jl @@ -207,7 +207,11 @@ end Δg = gswn[g] Δv = gswn[v] @test wnd(fake_data) ≈ d(fake_data) - @test sum(ΔW .* v ./ normv, dims = WN_dim...) ≈ Δg + if isa(WN_dim, Int) + @test sum(ΔW .* v ./ normv, dims = WN_dim) ≈ Δg + else + @test sum(ΔW .* v ./ normv, dims = WN_dim[1]) ≈ Δg + end @test g ./ normv .* ΔW - g .* Δg .* v ./ (normv.^2) ≈ Δv @test size(Δv) == size(ΔW) @test isa(wnd.layer.W, Flux.WeightNormParam)