From ad6586808bf9ce68cc9fc39cf2cb94344727402d Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Thu, 21 Oct 2021 11:52:51 +0200 Subject: [PATCH 1/3] fix bug in drls --- src/algorithms/drls.jl | 35 +++++++++++++++++++++++++++-------- 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/src/algorithms/drls.jl b/src/algorithms/drls.jl index 2054df4..a6eaf4c 100644 --- a/src/algorithms/drls.jl +++ b/src/algorithms/drls.jl @@ -58,13 +58,32 @@ Base.@kwdef mutable struct DRLSState{R,Tx,TH} tau::Maybe{R} = nothing end -function DRE(state::DRLSState) - return ( - state.f_u + state.g_v - real(dot(state.x - state.u, state.res)) / state.gamma + - 1 / (2 * state.gamma) * norm(state.res)^2 - ) +DRE(f_u::Number, g_v::Number, x, u, res, gamma) = f_u + g_v - real(dot(x - u, res)) / gamma + 1 / (2 * gamma) * norm(res)^2 + +DRE(state::DRLSState) = DRE(state.f_u, state.g_v, state.x, state.u, state.res, state.gamma) + +function DRE(f, g, x, gamma) + u, f_u = prox(f, x, gamma) + w = 2 * u - x + v, g_v = prox(g, w, gamma) + res = u - v + return DRE(f_u, g_v, x, u, res, gamma) end +# function DRS_step(f, g, x, gamma, lambda) +# u, f_u = prox(f, x, gamma) +# w = 2 * u - x +# v, g_v = prox(g, w, gamma) +# res = u - v +# return x - lambda * res, res +# end + +# function DRS_consistency_check(iter, state) +# xbar, res = DRS_step(iter.f, iter.g, state.x, iter.gamma, iter.lambda) +# @assert all(res .== state.res) +# @assert all(xbar .== state.xbar) +# end + function Base.iterate(iter::DRLSIteration) x = copy(iter.x0) u, f_u = prox(iter.f, x, iter.gamma) @@ -84,8 +103,8 @@ function Base.iterate(iter::DRLSIteration{R}, state::DRLSState) where {R} mul!(state.d, iter.H, -state.res) state.x_d .= state.x .+ state.d - copyto!(state.xbar_prev, state.xbar) - copyto!(state.res_prev, state.res) + state.xbar_prev, state.xbar = state.xbar, state.xbar_prev + state.res_prev, state.res = state.res, state.res_prev state.tau = R(1) state.x .= state.x_d @@ -102,7 +121,7 @@ function Base.iterate(iter::DRLSIteration{R}, state::DRLSState) where {R} state.xbar .= state.x .- iter.lambda * state.res DRE_candidate = DRE(state) - if DRE_candidate <= DRE_curr - iter.c / iter.gamma * norm(state.res)^2 + if DRE_candidate <= DRE_curr - iter.c / iter.gamma * norm(state.res_prev)^2 return state, state end From de5c6434a6340f3dbb676eb15ec928a1401fe7a2 Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Thu, 21 Oct 2021 11:58:25 +0200 Subject: [PATCH 2/3] remove garbage --- src/algorithms/drls.jl | 22 ---------------------- 1 file changed, 22 deletions(-) diff --git a/src/algorithms/drls.jl b/src/algorithms/drls.jl index a6eaf4c..be9a4ab 100644 --- a/src/algorithms/drls.jl +++ b/src/algorithms/drls.jl @@ -62,28 +62,6 @@ DRE(f_u::Number, g_v::Number, x, u, res, gamma) = f_u + g_v - real(dot(x - u, re DRE(state::DRLSState) = DRE(state.f_u, state.g_v, state.x, state.u, state.res, state.gamma) -function DRE(f, g, x, gamma) - u, f_u = prox(f, x, gamma) - w = 2 * u - x - v, g_v = prox(g, w, gamma) - res = u - v - return DRE(f_u, g_v, x, u, res, gamma) -end - -# function DRS_step(f, g, x, gamma, lambda) -# u, f_u = prox(f, x, gamma) -# w = 2 * u - x -# v, g_v = prox(g, w, gamma) -# res = u - v -# return x - lambda * res, res -# end - -# function DRS_consistency_check(iter, state) -# xbar, res = DRS_step(iter.f, iter.g, state.x, iter.gamma, iter.lambda) -# @assert all(res .== state.res) -# @assert all(xbar .== state.xbar) -# end - function Base.iterate(iter::DRLSIteration) x = copy(iter.x0) u, f_u = prox(iter.f, x, iter.gamma) From 5aeabe556139bcb4fbd440d150e1e839dab0f93b Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Thu, 21 Oct 2021 13:44:17 +0200 Subject: [PATCH 3/3] extract default settings --- src/algorithms/drls.jl | 38 +++++++++++++++++++++----------------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/src/algorithms/drls.jl b/src/algorithms/drls.jl index be9a4ab..06db489 100644 --- a/src/algorithms/drls.jl +++ b/src/algorithms/drls.jl @@ -10,30 +10,34 @@ using ProximalOperators: Zero using LinearAlgebra using Printf +function drls_default_gamma(f, Lf, alpha, lambda) + if ProximalOperators.is_convex(f) + return alpha / Lf + else + return alpha * (2 - lambda) / (2 * Lf) + end +end + +function drls_default_c(f, Lf, gamma, lambda, beta) + m = if ProximalOperators.is_convex(f) + max(gamma * Lf - lambda / 2, 0) + else + 1 + end + C_gamma_lambda = (lambda / ((1 + gamma * Lf)^2) * ((2 - lambda) / 2 - gamma * Lf * m)) + return beta * C_gamma_lambda +end + Base.@kwdef struct DRLSIteration{R,C<:Union{R,Complex{R}},Tx<:AbstractArray{C},Tf,Tg,TH} f::Tf = Zero() g::Tg = Zero() x0::Tx alpha::R = real(eltype(x0))(0.95) beta::R = real(eltype(x0))(0.5) - Lf::Maybe{R} = nothing - gamma::Maybe{R} = begin - if ProximalOperators.is_convex(f) - alpha / Lf - else - alpha * (2 - lambda) / (2 * Lf) - end - end lambda::R = real(eltype(x0))(1) - c::R = begin - m = if ProximalOperators.is_convex(f) - max(gamma * Lf - lambda / 2, 0) - else - 1 - end - C_gamma_lambda = (lambda / ((1 + gamma * Lf)^2) * ((2 - lambda) / 2 - gamma * Lf * m)) - c = beta * C_gamma_lambda - end + Lf::Maybe{R} = nothing + gamma::R = drls_default_gamma(f, Lf, alpha, lambda) + c::R = drls_default_c(f, Lf, gamma, lambda, beta) max_backtracks::Int = 20 H::TH = LBFGS(x0, 5) end