From 28c15061bb9409f6999a366d40d9a7920bc25bfc Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sun, 30 Jan 2022 10:30:34 -0500
Subject: [PATCH 1/2] log the loss during tests

---
 test/rules.jl | 39 +++++++++++++++++++++++++++++++++++----
 1 file changed, 35 insertions(+), 4 deletions(-)

diff --git a/test/rules.jl b/test/rules.jl
index c950aeb5..cffc1062 100644
--- a/test/rules.jl
+++ b/test/rules.jl
@@ -18,7 +18,16 @@ RULES = [
 name(o) = typeof(o).name.name
 name(o::OptimiserChain) = join(name.(o.opts), " → ")
 
+LOG = Dict()
+
+loggradient(o) = (f, xs...) -> begin
+  y, dxs = Zygote.withgradient(f, xs...)
+  push!(get!(() -> Float32[], LOG, name(o)), y)
+  dxs  # save the loss, return the gradient
+end
+
 @testset "independence" begin
+  empty!(LOG)
   @testset "$(name(o))" for o in RULES
     w = randn(10, 10)
     w′ = randn(10, 10)
@@ -27,7 +36,7 @@ name(o::OptimiserChain) = join(name.(o.opts), " → ")
     st = Optimisers.setup(o, w)
     for t = 1:10^5
       x = rand(10)
-      gs = gradient(w -> iloss(x, w, w′), w)
+      gs = loggradient(o)(w -> iloss(x, w, w′), w)
       st, w = Optimisers.update!(st, w, gs...)
     end
     @test iloss(rand(10, 10), w, w′) < 0.01
@@ -35,11 +44,12 @@ name(o::OptimiserChain) = join(name.(o.opts), " → ")
 end
 
 @testset verbose=true "simple sum" begin
+  empty!(LOG)
   @testset "$(name(o))" for o in RULES
     m = shuffle!(reshape(1:64, 8, 8) .+ 0.0)
     s = Optimisers.setup(o, m)
     for _ in 1:10^5
-      g = gradient(x -> sum(abs2, x + x'), m)[1]
+      g = loggradient(o)(x -> sum(abs2, x + x'), m)[1]
       s, m = Optimisers.update!(s, m, g)
     end
     # @test sum(m) < sum(1:64)
@@ -52,7 +62,19 @@ end
   end
 end
 
+#=
+plot(LOG[:ADAGrad]) # decline
+LOG[:ADAGrad][end] # 3869.4075f0
+
+plot(LOG[:AMSGrad]) # decline
+LOG[:AMSGrad][end] # 2742.004f0
+
+findfirst(isnan, LOG[:ADADelta]) # 182
+plot(LOG[:ADADelta][1:100], yaxis=:log10) # exp growth
+=#
+
 @testset "original" begin
+  empty!(LOG)
   @testset "$(name(o))" for o in RULES
     w′ = (α = rand(3, 3), β = rand(3, 3))
     w = (α = 5rand(3, 3), β = rand(3, 3))
@@ -60,7 +82,7 @@ end
     loss(x, y) = mean((x.α .* x.β .- y.α .* y.β) .^ 2)
     @test loss(w, w′) > 1
     for i = 1:10^4
-      gs = gradient(x -> loss(x, w′), w)
+      gs = loggradient(o)(x -> loss(x, w′), w)
       st, w = Optimisers.update(st, w, gs...)
     end
     lw = loss(w, w′)
@@ -74,6 +96,7 @@ end
 end
 
 @testset verbose=true "StaticArrays" begin
+  empty!(LOG)
   @testset "$(name(o))" for o in RULES
     W1 = @SMatrix randn(10, 10)
     b1 = @SVector randn(10)
@@ -87,7 +110,7 @@ end
     @test s_loss(model, x, y) > 10
     state = Optimisers.setup(o, model)
     for t = 1:10^3
-      g = gradient(m -> s_loss(m, x, y), model)[1]
+      g = loggradient(o)(m -> s_loss(m, x, y), model)[1]
      state, model = Optimisers.update!(state, model, g)
     end
     if o isa Union{Descent, RMSProp, ADAGrad, ADADelta, NADAM}
@@ -99,6 +122,14 @@ end
   end
 end
 
+#=
+plot(LOG[:Descent]) # decline
+plot(LOG[:RMSProp]) # const 10^11
+plot(LOG[:ADAGrad]) # const 10^8
+plot(LOG[:ADADelta][1:30], yaxis=:log10) # exp growth
+plot(LOG[name(NADAM())]) # stuck at 10^11
+=#
+
 @testset verbose=true "element types" begin
   @testset "$(name(o))" for o in RULES
     marray = (Float32[1,2], Float64[3,4], Float16[5,6])

From 9293965949cc5447ad98a20c46c58aa064821396 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sun, 30 Jan 2022 10:47:39 -0500
Subject: [PATCH 2/2] fixup ADAMW

---
 src/rules.jl  | 16 ++++++++++------
 test/rules.jl |  7 ++++++-
 2 files changed, 16 insertions(+), 7 deletions(-)

diff --git a/src/rules.jl b/src/rules.jl
index f2e1ee5c..2ead189e 100644
--- a/src/rules.jl
+++ b/src/rules.jl
@@ -392,8 +392,8 @@ end
 """
     ADAMW(η = 1f-3, β = (9f-1, 9.99f-1), γ = 0, ϵ = eps(typeof(η)))
 
-[ADAMW](https://arxiv.org/abs/1711.05101) is a variant of ADAM fixing (as in repairing) its
-weight decay regularization.
+[ADAMW](https://arxiv.org/abs/1711.05101) is a variant of [ADAM](@ref) fixing
+(as in repairing) its weight decay regularization.
 
 # Parameters
 - Learning rate (`η`): Amount by which gradients are discounted before updating
@@ -405,7 +405,7 @@ weight decay regularization.
   (no need to change default)
 """
 ADAMW(η = 1f-3, β = (9f-1, 9.99f-1), γ = 0, ϵ = eps(typeof(η))) =
-  OptimiserChain(ADAM{typeof(η)}(η, β, ϵ), WeightDecay(γ))
+  OptimiserChain(ADAM{typeof(η)}(1, β, ϵ), WeightDecay{typeof(η)}(γ), Descent(η))
 
 """
     AdaBelief(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η)))
@@ -444,20 +444,24 @@ end
 """
     WeightDecay(γ = 5f-4)
 
-Decay weights by `γ`.
+Decay weights by ``γ``, that is, add `γ .* x` to the gradient `x̄` which will be
+subtracted from `x`.
+
+Typically composed with other optimizers as the first transformation to the gradient,
+making it equivalent to adding ``L_2`` regularization with coefficient ``γ`` to the loss.
 
 # Parameters
 - Weight decay (`γ`): Decay applied to weights during optimisation.
 """
 struct WeightDecay{T}
-  wd::T
+  gamma::T
 end
 WeightDecay() = WeightDecay(5f-4)
 
 init(o::WeightDecay, x::AbstractArray) = nothing
 
 function apply!(o::WeightDecay, state, x, dx)
-  dx′ = @.. dx + o.wd * x
+  dx′ = @.. dx + o.gamma * x
 
   return state, dx′
 end
diff --git a/test/rules.jl b/test/rules.jl
index cffc1062..74290848 100644
--- a/test/rules.jl
+++ b/test/rules.jl
@@ -80,7 +80,7 @@ plot(LOG[:ADADelta][1:100], yaxis=:log10) # exp growth
     w = (α = 5rand(3, 3), β = rand(3, 3))
     st = Optimisers.setup(o, w)
     loss(x, y) = mean((x.α .* x.β .- y.α .* y.β) .^ 2)
-    @test loss(w, w′) > 1
+    @test loss(w, w′) > 1  # guard against accidentally having loss 0
     for i = 1:10^4
       gs = loggradient(o)(x -> loss(x, w′), w)
       st, w = Optimisers.update(st, w, gs...)
@@ -95,6 +95,11 @@ plot(LOG[:ADADelta][1:100], yaxis=:log10) # exp growth
   end
 end
 
+#=
+findfirst(isnan, LOG[:ADADelta]) # 11
+plot(LOG[:ADADelta][1:12], yaxis=:log10)
+=#
+
 @testset verbose=true "StaticArrays" begin
   empty!(LOG)
   @testset "$(name(o))" for o in RULES
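
A quick way to sanity-check the ADAMW change in [PATCH 2/2] is to compare the new
definition against the same chain written out by hand. This is a minimal sketch, not
part of either patch: it assumes the patched Optimisers.jl is loaded, and the η, β, γ
values below are arbitrary.

using Optimisers

η, β, γ = 1f-3, (9f-1, 9.99f-1), 1f-2   # arbitrary illustrative values
x = randn(Float32, 4)
dx = 2 .* x                             # gradient of sum(abs2, x), written by hand

# Per the patched definition, ADAMW is ADAM with unit learning rate,
# then decoupled weight decay, with the step size η applied last.
o1 = ADAMW(η, β, γ)
o2 = OptimiserChain(ADAM(1f0, β), WeightDecay(γ), Descent(η))

s1 = Optimisers.setup(o1, x)
s2 = Optimisers.setup(o2, x)
_, x1 = Optimisers.update(s1, copy(x), dx)
_, x2 = Optimisers.update(s2, copy(x), dx)
x1 ≈ x2   # expected: both subtract η .* (adam_direction .+ γ .* x)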