Skip to content

Commit

Permalink
Fix bug in line search in DRLS (#53)
Browse files Browse the repository at this point in the history
  • Loading branch information
lostella authored Oct 21, 2021
1 parent 96fd0c2 commit b89725b
Showing 1 changed file with 27 additions and 26 deletions.
53 changes: 27 additions & 26 deletions src/algorithms/drls.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,30 +10,34 @@ using ProximalOperators: Zero
using LinearAlgebra
using Printf

function drls_default_gamma(f, Lf, alpha, lambda)
if ProximalOperators.is_convex(f)
return alpha / Lf
else
return alpha * (2 - lambda) / (2 * Lf)
end
end

function drls_default_c(f, Lf, gamma, lambda, beta)
m = if ProximalOperators.is_convex(f)
max(gamma * Lf - lambda / 2, 0)
else
1
end
C_gamma_lambda = (lambda / ((1 + gamma * Lf)^2) * ((2 - lambda) / 2 - gamma * Lf * m))
return beta * C_gamma_lambda
end

Base.@kwdef struct DRLSIteration{R,C<:Union{R,Complex{R}},Tx<:AbstractArray{C},Tf,Tg,TH}
f::Tf = Zero()
g::Tg = Zero()
x0::Tx
alpha::R = real(eltype(x0))(0.95)
beta::R = real(eltype(x0))(0.5)
Lf::Maybe{R} = nothing
gamma::Maybe{R} = begin
if ProximalOperators.is_convex(f)
alpha / Lf
else
alpha * (2 - lambda) / (2 * Lf)
end
end
lambda::R = real(eltype(x0))(1)
c::R = begin
m = if ProximalOperators.is_convex(f)
max(gamma * Lf - lambda / 2, 0)
else
1
end
C_gamma_lambda = (lambda / ((1 + gamma * Lf)^2) * ((2 - lambda) / 2 - gamma * Lf * m))
c = beta * C_gamma_lambda
end
Lf::Maybe{R} = nothing
gamma::R = drls_default_gamma(f, Lf, alpha, lambda)
c::R = drls_default_c(f, Lf, gamma, lambda, beta)
max_backtracks::Int = 20
H::TH = LBFGS(x0, 5)
end
Expand All @@ -58,12 +62,9 @@ Base.@kwdef mutable struct DRLSState{R,Tx,TH}
tau::Maybe{R} = nothing
end

function DRE(state::DRLSState)
return (
state.f_u + state.g_v - real(dot(state.x - state.u, state.res)) / state.gamma +
1 / (2 * state.gamma) * norm(state.res)^2
)
end
DRE(f_u::Number, g_v::Number, x, u, res, gamma) = f_u + g_v - real(dot(x - u, res)) / gamma + 1 / (2 * gamma) * norm(res)^2

DRE(state::DRLSState) = DRE(state.f_u, state.g_v, state.x, state.u, state.res, state.gamma)

function Base.iterate(iter::DRLSIteration)
x = copy(iter.x0)
Expand All @@ -84,8 +85,8 @@ function Base.iterate(iter::DRLSIteration{R}, state::DRLSState) where {R}

mul!(state.d, iter.H, -state.res)
state.x_d .= state.x .+ state.d
copyto!(state.xbar_prev, state.xbar)
copyto!(state.res_prev, state.res)
state.xbar_prev, state.xbar = state.xbar, state.xbar_prev
state.res_prev, state.res = state.res, state.res_prev
state.tau = R(1)
state.x .= state.x_d

Expand All @@ -102,7 +103,7 @@ function Base.iterate(iter::DRLSIteration{R}, state::DRLSState) where {R}
state.xbar .= state.x .- iter.lambda * state.res
DRE_candidate = DRE(state)

if DRE_candidate <= DRE_curr - iter.c / iter.gamma * norm(state.res)^2
if DRE_candidate <= DRE_curr - iter.c / iter.gamma * norm(state.res_prev)^2
return state, state
end

Expand Down

0 comments on commit b89725b

Please sign in to comment.