diff --git a/src/algorithms/drls.jl b/src/algorithms/drls.jl index 2054df4..06db489 100644 --- a/src/algorithms/drls.jl +++ b/src/algorithms/drls.jl @@ -10,30 +10,34 @@ using ProximalOperators: Zero using LinearAlgebra using Printf +function drls_default_gamma(f, Lf, alpha, lambda) + if ProximalOperators.is_convex(f) + return alpha / Lf + else + return alpha * (2 - lambda) / (2 * Lf) + end +end + +function drls_default_c(f, Lf, gamma, lambda, beta) + m = if ProximalOperators.is_convex(f) + max(gamma * Lf - lambda / 2, 0) + else + 1 + end + C_gamma_lambda = (lambda / ((1 + gamma * Lf)^2) * ((2 - lambda) / 2 - gamma * Lf * m)) + return beta * C_gamma_lambda +end + Base.@kwdef struct DRLSIteration{R,C<:Union{R,Complex{R}},Tx<:AbstractArray{C},Tf,Tg,TH} f::Tf = Zero() g::Tg = Zero() x0::Tx alpha::R = real(eltype(x0))(0.95) beta::R = real(eltype(x0))(0.5) - Lf::Maybe{R} = nothing - gamma::Maybe{R} = begin - if ProximalOperators.is_convex(f) - alpha / Lf - else - alpha * (2 - lambda) / (2 * Lf) - end - end lambda::R = real(eltype(x0))(1) - c::R = begin - m = if ProximalOperators.is_convex(f) - max(gamma * Lf - lambda / 2, 0) - else - 1 - end - C_gamma_lambda = (lambda / ((1 + gamma * Lf)^2) * ((2 - lambda) / 2 - gamma * Lf * m)) - c = beta * C_gamma_lambda - end + Lf::Maybe{R} = nothing + gamma::R = drls_default_gamma(f, Lf, alpha, lambda) + c::R = drls_default_c(f, Lf, gamma, lambda, beta) max_backtracks::Int = 20 H::TH = LBFGS(x0, 5) end @@ -58,12 +62,9 @@ Base.@kwdef mutable struct DRLSState{R,Tx,TH} tau::Maybe{R} = nothing end -function DRE(state::DRLSState) - return ( - state.f_u + state.g_v - real(dot(state.x - state.u, state.res)) / state.gamma + - 1 / (2 * state.gamma) * norm(state.res)^2 - ) -end +DRE(f_u::Number, g_v::Number, x, u, res, gamma) = f_u + g_v - real(dot(x - u, res)) / gamma + 1 / (2 * gamma) * norm(res)^2 + +DRE(state::DRLSState) = DRE(state.f_u, state.g_v, state.x, state.u, state.res, state.gamma) function Base.iterate(iter::DRLSIteration) x = copy(iter.x0) @@ -84,8 +85,8 @@ function Base.iterate(iter::DRLSIteration{R}, state::DRLSState) where {R} mul!(state.d, iter.H, -state.res) state.x_d .= state.x .+ state.d - copyto!(state.xbar_prev, state.xbar) - copyto!(state.res_prev, state.res) + state.xbar_prev, state.xbar = state.xbar, state.xbar_prev + state.res_prev, state.res = state.res, state.res_prev state.tau = R(1) state.x .= state.x_d @@ -102,7 +103,7 @@ function Base.iterate(iter::DRLSIteration{R}, state::DRLSState) where {R} state.xbar .= state.x .- iter.lambda * state.res DRE_candidate = DRE(state) - if DRE_candidate <= DRE_curr - iter.c / iter.gamma * norm(state.res)^2 + if DRE_candidate <= DRE_curr - iter.c / iter.gamma * norm(state.res_prev)^2 return state, state end