Fix ADAMW, and track the loss #46

Closed
wants to merge 2 commits
16 changes: 10 additions & 6 deletions src/rules.jl
@@ -392,8 +392,8 @@ end
"""
ADAMW(η = 1f-3, β = (9f-1, 9.99f-1), γ = 0, ϵ = eps(typeof(η)))

[ADAMW](https://arxiv.org/abs/1711.05101) is a variant of ADAM fixing (as in repairing) its
weight decay regularization.
[ADAMW](https://arxiv.org/abs/1711.05101) is a variant of [ADAM](@ref) fixing
(as in repairing) its weight decay regularization.

# Parameters
- Learning rate (`η`): Amount by which gradients are discounted before updating
@@ -405,7 +405,7 @@ weight decay regularization.
(no need to change default)
"""
ADAMW(η = 1f-3, β = (9f-1, 9.99f-1), γ = 0, ϵ = eps(typeof(η))) =
OptimiserChain(ADAM{typeof(η)}(η, β, ϵ), WeightDecay(γ))
OptimiserChain(ADAM{typeof(η)}(1, β, ϵ), WeightDecay{typeof(η)}(γ), Descent(η))
ToucheSir (Member) commented on Jan 30, 2022:
Sorry, I just remembered a previous comment (maybe on Slack? Can't find it now) which noted that our implementation did not match the paper:
[image: excerpt of the update rule from the paper]

TL;DR: possibly because of an odd choice of variable names in the paper (it uses α for the learning rate and η for something else), the "fixed" version is wrong and the original is correct. Apologies for the noise.
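For reference, a rough restatement of the update in the paper's Algorithm 2 (reading its η_t as only a schedule multiplier):

x ← x − η_t · (α · m̂ / (√v̂ + ϵ) + λ · x)

so with a constant schedule the decay term λ·x is not scaled by the step size α, which is what the original `OptimiserChain(ADAM(η, β, ϵ), WeightDecay(γ))` composition produces.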

Another member replied:

oh damn, Brian is right, FluxML/Flux.jl#1612 should be reverted, I got confused by the nomenclature

A further reply:
I'll file a PR in Flux
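To make the difference concrete, here is a minimal sketch (assuming `OptimiserChain` applies its rules to the gradient from left to right; hyperparameter values are made up):

```julia
using Optimisers

η, β, γ, ϵ = 1f-3, (9f-1, 9.99f-1), 1f-2, eps(Float32)

# Original definition: ADAM scales by η first, then the decoupled decay γ .* x is added:
#     x ← x - (η * m̂/(√v̂ + ϵ) + γ * x)     # decay not scaled by η, as in the paper
adamw_original = OptimiserChain(ADAM(η, β, ϵ), WeightDecay(γ))

# This PR's version: ADAM with unit step, decay added, then everything multiplied by η:
#     x ← x - η * (m̂/(√v̂ + ϵ) + γ * x)     # decay is scaled by η
adamw_pr = OptimiserChain(ADAM(1f0, β, ϵ), WeightDecay(γ), Descent(η))
```

With the original composition the decay coefficient γ keeps its meaning independently of the learning rate, which is the point of decoupled weight decay.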


"""
AdaBelief(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η)))
@@ -444,20 +444,24 @@ end
"""
WeightDecay(γ = 5f-4)

Decay weights by `γ`.
Decay weights by ``γ``, that is, add `γ .* x` to the gradient `x̄` which will be
subtracted from `x`.

Typically composed with other optimizers as the first transformation to the gradient,
making it equivalent to adding ``L_2`` regularization with coefficient ``γ`` to the loss.

# Parameters
- Weight decay (`γ`): Decay applied to weights during optimisation.
"""
struct WeightDecay{T}
wd::T
gamma::T
end
WeightDecay() = WeightDecay(5f-4)

init(o::WeightDecay, x::AbstractArray) = nothing

function apply!(o::WeightDecay, state, x, dx)
dx′ = @.. dx + o.wd * x
dx′ = @.. dx + o.gamma * x

return state, dx′
end
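A minimal sketch of the L2 equivalence the new docstring describes, writing the penalty as `γ/2 * sum(abs2, x)` so that its gradient is exactly `γ .* x`; the loss and values below are made up for illustration:

```julia
using Optimisers, Zygote

γ, η = 5f-4, 1f-1
x0 = randn(Float32, 3, 3)
loss(x) = sum(abs2, x .- 1)

# (a) Plain loss, with WeightDecay applied to the gradient before Descent:
xa = copy(x0)
sa = Optimisers.setup(OptimiserChain(WeightDecay(γ), Descent(η)), xa)
sa, xa = Optimisers.update(sa, xa, Zygote.gradient(loss, xa)[1])

# (b) Descent alone, on the loss plus an explicit L2 penalty:
xb = copy(x0)
sb = Optimisers.setup(Descent(η), xb)
sb, xb = Optimisers.update(sb, xb, Zygote.gradient(x -> loss(x) + γ/2 * sum(abs2, x), xb)[1])

xa ≈ xb   # true: one step of either form moves the parameters identically
```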
46 changes: 41 additions & 5 deletions test/rules.jl
@@ -18,7 +18,16 @@ RULES = [
name(o) = typeof(o).name.name
name(o::OptimiserChain) = join(name.(o.opts), " → ")

LOG = Dict()

loggradient(o) = (f, xs...) -> begin
y, dxs = Zygote.withgradient(f, xs...)
push!(get!(() -> Float32[], LOG, name(o)), y)
dxs # save the loss, return the gradient
end

@testset "independence" begin
empty!(LOG)
@testset "$(name(o))" for o in RULES
w = randn(10, 10)
w′ = randn(10, 10)
@@ -27,19 +36,20 @@ name(o::OptimiserChain) = join(name.(o.opts), " → ")
st = Optimisers.setup(o, w)
for t = 1:10^5
x = rand(10)
gs = gradient(w -> iloss(x, w, w′), w)
gs = loggradient(o)(w -> iloss(x, w, w′), w)
st, w = Optimisers.update!(st, w, gs...)
end
@test iloss(rand(10, 10), w, w′) < 0.01
end
end

@testset verbose=true "simple sum" begin
empty!(LOG)
@testset "$(name(o))" for o in RULES
m = shuffle!(reshape(1:64, 8, 8) .+ 0.0)
s = Optimisers.setup(o, m)
for _ in 1:10^5
g = gradient(x -> sum(abs2, x + x'), m)[1]
g = loggradient(o)(x -> sum(abs2, x + x'), m)[1]
s, m = Optimisers.update!(s, m, g)
end
# @test sum(m) < sum(1:64)
@@ -52,15 +62,27 @@ end
end
end

#=
plot(LOG[:ADAGrad]) # decline
LOG[:ADAGrad][end] # 3869.4075f0

plot(LOG[:AMSGrad]) # decline
LOG[:AMSGrad][end] # 2742.004f0

findfirst(isnan, LOG[:ADADelta]) # 182
plot(LOG[:ADADelta][1:100], yaxis=:log10) # exp growth
=#

@testset "original" begin
empty!(LOG)
@testset "$(name(o))" for o in RULES
w′ = (α = rand(3, 3), β = rand(3, 3))
w = (α = 5rand(3, 3), β = rand(3, 3))
st = Optimisers.setup(o, w)
loss(x, y) = mean((x.α .* x.β .- y.α .* y.β) .^ 2)
@test loss(w, w′) > 1
@test loss(w, w′) > 1 # guard against accidentally having loss 0
for i = 1:10^4
gs = gradient(x -> loss(x, w′), w)
gs = loggradient(o)(x -> loss(x, w′), w)
st, w = Optimisers.update(st, w, gs...)
end
lw = loss(w, w′)
@@ -73,7 +95,13 @@ end
end
end

#=
findfirst(isnan, LOG[:ADADelta]) # 11
plot(LOG[:ADADelta][1:12], yaxis=:log10)
=#

@testset verbose=true "StaticArrays" begin
empty!(LOG)
@testset "$(name(o))" for o in RULES
W1 = @SMatrix randn(10, 10)
b1 = @SVector randn(10)
@@ -87,7 +115,7 @@ end
@test s_loss(model, x, y) > 10
state = Optimisers.setup(o, model)
for t = 1:10^3
g = gradient(m -> s_loss(m, x, y), model)[1]
g = loggradient(o)(m -> s_loss(m, x, y), model)[1]
state, model = Optimisers.update!(state, model, g)
end
if o isa Union{Descent, RMSProp, ADAGrad, ADADelta, NADAM}
@@ -99,6 +127,14 @@ end
end
end

#=
plot(LOG[:Descent]) # decline
plot(LOG[:RMSProp]) # const 10^11
plot(LOG[:ADAGrad]) # const 10^8
plot(LOG[:ADADelta][1:30], yaxis=:log10) # exp growth
plot(LOG[name(NADAM())]) # stuck at 10^11
=#

@testset verbose=true "element types" begin
@testset "$(name(o))" for o in RULES
marray = (Float32[1,2], Float64[3,4], Float16[5,6])