From 5ec25bdd576c59716b54ff117b4a22f73d6498e4 Mon Sep 17 00:00:00 2001 From: an-awesome-guy Date: Thu, 16 Jan 2020 19:35:52 +0530 Subject: [PATCH] Added Weight Normalization --- src/Flux.jl | 2 +- src/layers/basic.jl | 45 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/src/Flux.jl b/src/Flux.jl index 9969b32346..30b4193255 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -11,7 +11,7 @@ export gradient export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool, DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm, - SkipConnection, params, fmap, cpu, gpu, f32, f64 + SkipConnection, params, fmap, cpu, gpu, f32, f64, DenseWN include("optimise/Optimise.jl") using .Optimise diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 2a46520818..06f693feb5 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -226,3 +226,48 @@ end function Base.show(io::IO, b::SkipConnection) print(io, "SkipConnection(", b.layers, ", ", b.connection, ")") end + +""" +Creates a `DenseWN` layer with parameters `v`,'g' and `b`. +(A 'Dense' reparameterized ) + +y = σ.(W * x .+ b) where W = g*(v/||v||) + + +The input `x` must be a vector of length `in`, or a batch of vectors represented +as an `in × N` matrix. The out `y` will be a vector or batch of length `out`. + +```julia +julia> c = DenseWN(5, 4); + +julia> c(rand(5)) +4-element Array{Float64,1}: + 0.31057365817351623 + 0.5403848812373727 + 0.07163342873445384 + 0.1526865185646453 +## References +[Weight Normalization](https://arxiv.org/abs/1602.07868) + A Simple Reparameterization to Accelerate Training of Deep Neural Networks. +""" + +struct DenseWN{F,S,T} + v::S + g::T + b::T + σ::F +end + +DenseWN(W, b) = DenseWN(W, b, identity) + +function DenseWN(in::Integer, out::Integer, σ = identity; + initv = glorot_uniform,initg = ones, initb = zeros) + return DenseWN(initv(out, in),initg(1) ,initb(out), σ) +end + +@functor DenseWN + +function (a::DenseWN)(x::AbstractArray) + v, g, b, σ = a.v, a.g ,a.b , a.σ + σ.((g.*v*x/√sum(v.^2)) .+ b) +end