Added WeightNorm #1005
Conversation
I don't know why convs are not working with your code, but this alternative solution should work.

struct WeightNorm
layer
g
v
weight::Symbol
dims
eps
end
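# The helpers below implement the weight-normalization reparametrization of
# Salimans & Kingma (2016): w = g .* v ./ ||v||, with the norm taken over `dims`.
# WN_mag extracts the magnitude, WN_dir the direction, and WN_reconstr rebuilds
# the effective weight from the stored g and v.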
WN_mag(p, dims) = sqrt.(sum(abs2.(p), dims = dims))
WN_dir(p, mag, eps=eps(eltype(p))) = p ./ (mag .+ eps)
WN_reconstr(wn::WeightNorm) = wn.g .* wn.v ./ WN_mag(wn.v, wn.dims)
function WeightNorm(layer, weight::Union{Symbol,Int}; dims)
#Expose layer fields and constructor
func, re = Flux.functor(layer)
#Get the fields
w = getfield(layer, weight)
g = WN_mag(w, dims)
v = WN_dir(w, g)
# Reconstruct the layer, swapping w for v (let's not waste memory)
replace(name) = name == weight ? v : getfield(layer, name)
par = [replace(name) for name in keys(func)]
WeightNorm(re(par), g, v, weight, dims, eps(Float32))
end
function (wn::WeightNorm)(x)
func, re = Flux.functor(wn.layer)
w = WN_reconstr(wn)
replace(name) = name == wn.weight ? w : getfield(wn.layer, name)
par = [replace(name) for name in keys(func)]
re(par)(x)
end
Flux.@functor WeightNorm

This approach seems simpler: we don't have to define custom arrays.

julia> m = Flux.Dense(2,3)
Dense(2, 3)
julia> wn = FluxDNN.WeightNorm(m, :W, dims=1)
FluxDNN.WeightNorm(Dense(2, 3), Float32[1.1312975 0.5932242], Float32[-0.2702223 -0.22867282; -0.23115203 -0.31670466; 0.93463814 0.9205468], :W, 1, 1.1920929f-7)
## FORWARDS ARE THE SAME
julia> m(ones(2))
3-element Array{Float32,1}:
-0.44135612
-0.44937864
1.6034446
julia> wn(ones(2))
3-element Array{Float32,1}:
-0.44135612
-0.44937867
1.6034446
#### BACKWARDS
julia> Flux.gradient(params(m)) do
sum(m(ones(2)).^2)
end.grads
IdDict{Any,Any} with 4 entries:
RefValue{typeof(^)}(^) => RefValue{Any}((x = nothing,))
RefValue{Val{2}}(Val{2}()) => RefValue{Any}((x = nothing,))
Float32[-0.305702 -0.135654… => Float32[-0.882712 -0.882712; -0.898757 -0.898757; 3.20…
Float32[0.0, 0.0, 0.0] => Float32[-0.882712, -0.898757, 3.20689]
julia> Flux.gradient(params(wn)) do
sum(wn(ones(2)).^2)
end.grads
IdDict{Any,Any} with 5 entries:
RefValue{typeof(^)}(^) => RefValue{Any}((x = nothing,))
Float32[1.1313 0.593224] => Float32[3.44356 3.43859]
Float32[-0.270222 -0.228673… => Float32[0.0540923 -0.0571875; -0.116265 0.112866; -0.0…
RefValue{Val{2}}(Val{2}()) => RefValue{Any}((x = nothing,))
Float32[0.0, 0.0, 0.0] => Float32[-1.76542, -1.79751, 6.41378]
Also convolutions seem fine. Didn't test it on gpu yet.

julia> m = Flux.Conv((2,2), 2=>3)
Conv((2, 2), 2=>3)
julia> wn = FluxDNN.WeightNorm(m, :weight, dims=1);
julia> m(ones(3,3,2,2))
2×2×3×2 Array{Float64,4}:
[:, :, 1, 1] =
0.0310524 0.0310524
0.0310524 0.0310524
[:, :, 2, 1] =
0.676353 0.676353
0.676353 0.676353
[:, :, 3, 1] =
-0.733513 -0.733513
-0.733513 -0.733513
[:, :, 1, 2] =
0.0310524 0.0310524
0.0310524 0.0310524
[:, :, 2, 2] =
0.676353 0.676353
0.676353 0.676353
[:, :, 3, 2] =
-0.733513 -0.733513
-0.733513 -0.733513
julia> wn(ones(3,3,2,2))
2×2×3×2 Array{Float64,4}:
[:, :, 1, 1] =
0.0310524 0.0310524
0.0310524 0.0310524
[:, :, 2, 1] =
0.676353 0.676353
0.676353 0.676353
[:, :, 3, 1] =
-0.733513 -0.733513
-0.733513 -0.733513
[:, :, 1, 2] =
0.0310524 0.0310524
0.0310524 0.0310524
[:, :, 2, 2] =
0.676353 0.676353
0.676353 0.676353
[:, :, 3, 2] =
-0.733513 -0.733513
-0.733513 -0.733513
julia> Flux.gradient(params(m)) do
sum(m(ones(3,3,2,2)).^2)
end.grads
IdDict{Any,Any} with 4 entries:
Float32[0.0, 0.0, 0.0] => [0.496838, 10.8217, -11.7362]
Float32[0.429876 -0.0354687… => [0.496838 0.496838; 0.496838 0.496838]…
RefValue{typeof(^)}(^) => RefValue{Any}((x = nothing,))
RefValue{Val{2}}(Val{2}()) => RefValue{Any}((x = nothing,))
julia> Flux.gradient(params(wn)) do
sum(wn(ones(3,3,2,2)).^2)
end.grads
IdDict{Any,Any} with 5 entries:
Float32[0.653553 -0.241225;… => [0.348865 0.0517; 0.301239 -0.0128509]…
RefValue{typeof(^)}(^) => RefValue{Any}((x = nothing,))
RefValue{Val{2}}(Val{2}()) => RefValue{Any}((x = nothing,))
Float32[0.0, 0.0, 0.0] => [0.993675, 21.6433, -23.4724]
Float32[0.657753 0.147035]… => [-0.0513368 -0.602016]…
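A quick check of the GPU path might look like the sketch below (untested; it assumes a CUDA-capable setup where Flux's gpu function moves layers and arrays to the device, and uses the FluxDNN.WeightNorm wrapper defined above):

using Flux

m_gpu  = Flux.Conv((2,2), 2=>3) |> gpu                           # move the layer to the device
wn_gpu = FluxDNN.WeightNorm(m_gpu, :weight, dims=1)              # then wrap it
x_gpu  = gpu(ones(Float32, 3, 3, 2, 2))

wn_gpu(x_gpu)                                                    # forward pass on the GPU
Flux.gradient(() -> sum(wn_gpu(x_gpu).^2), Flux.params(wn_gpu))  # backward pass on the GPU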
I thought of something similar, but reconstructing the layer at every call sounded like it could incur a hefty penalty to performance. We can benchmark the implementations, and also use the tests we already have for the gradients to compare them later.
I don't think constructing an object out of precomputed fields is expensive, at least not compared to the weight reconstruction that we have to perform anyway. In any case, let's do some benchmarks.
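A minimal sketch of such a benchmark, comparing the wrapped layer against the plain layer (it assumes BenchmarkTools and the FluxDNN.WeightNorm wrapper defined above; the sizes are arbitrary, and the same pattern applies to the PR's own implementation):

using Flux, BenchmarkTools

m  = Flux.Dense(10, 10)
wn = FluxDNN.WeightNorm(m, :W, dims=1)   # functor-based wrapper from this thread
x  = ones(Float32, 10, 16)

@btime $m($x);                           # plain layer, forward
@btime $wn($x);                          # wrapped layer, forward

@btime Flux.gradient(() -> sum($m($x).^2),  Flux.params($m));   # plain layer, backward
@btime Flux.gradient(() -> sum($wn($x).^2), Flux.params($wn));  # wrapped layer, backward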
@bhvieira would you like to implement #1005 (comment) here, or do you want me to open a separate PR?
We should really benchmark both before that, but feel free to open a PR against my branch at bhvieira:weightnorm. I really like that model of credit assignment, since that way we both get to keep authorship of the commits in this PR.
@CarloLucibello I ran the benchmarks, the backward passes don't look that much different (only 2 times slower), but the forward pass is around 5 times slower than mine (for comparison, mine is already 3 times slower than the normal layer without WeightNorm).
The gradients are correct though; perhaps it's only a matter of optimizing your solution, as it appears to work with any layer out of the box.
Mhm, I wonder what is causing this overhead. Could you try with larger layers, e.g. m = Flux.Dense(10,10) and m = Flux.Dense(100,100)? Hopefully the overhead with respect to the current PR decreases.
The new tests should pass, but I just identified a big problem: WeightNorm refuses to work with the bias parameters of Dense.
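For illustration, the problematic case is a call along these lines (a sketch only, using the functor-based wrapper from earlier in the thread; :b is the bias field of Dense in this Flux version, and where exactly it breaks depends on the implementation being discussed):

m = Flux.Dense(2, 3)
# Dense carries the weight matrix :W and the bias vector :b;
# normalizing the bias is the case reported to fail:
wn_b = FluxDNN.WeightNorm(m, :b, dims=1)
wn_b(ones(Float32, 2))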
I'll try that later when I find some time, @CarloLucibello.
Commits added:
- WeightNorm for several params, single dim
- Test for Scalar and Vector dims
- Test newly created WN equality
- Simplified some bits
- Missing last constructor
Can't make the following example work no matter what I try:
I can envision what could cause WeightNorm to fail, but I have no clue where the mutation occurs.
I'm closing this PR; I haven't advanced on it and I'm out of time these days.
In light of #993 I wanted to create a WeightNorm constructor that could operate over any layer, param and dim.
This proved harder than I thought, but with the help of @chengchingwen and @mcabbott I could finally make this work.
So the execution is simply this: you call WeightNorm on a layer, the chosen weight is wrapped as a WeightNormWeight, and the functor functionality rebuilds the layer with the substitution in place.

The catch is that it is not ready though:
- [ ] Make it work with NNlib.conv (I really want to make this work, but now I noticed it's not only Conv that will be laborious to fix; simple stuff like RNNs will also require some iterations to get working, simply because they hide their params inside cells)
- [ ] Make it work with Base array operations. But I couldn't subtype WeightNormWeight into AbstractArray either
- [ ] Make it work with more than one weight at a time, potentially with differing dim as well

So, here's a first try at least. I could throw an error for Conv, but I want to make it work, and perhaps it'll be better to wait until everything is right.

Inputs welcome!
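For reference, a sketch of the intended workflow end to end, using the functor-based wrapper from the discussion above (the constructor name and fields of the PR's own implementation may differ):

using Flux

m  = Flux.Dense(2, 3)
wn = FluxDNN.WeightNorm(m, :W, dims=1)   # normalize :W along dims=1

ps   = Flux.params(wn)                   # collects g, v and the untouched bias
opt  = Descent(0.1)
data = [(ones(Float32, 2), ones(Float32, 3))]

Flux.train!((x, y) -> Flux.mse(wn(x), y), ps, data, opt)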