Skip to content

Commit

Permalink
Add implementation of S-FISTA (#48)
Browse files Browse the repository at this point in the history
  • Loading branch information
wwkong authored Sep 24, 2021
1 parent 2f5d17d commit cf6f7e9
Show file tree
Hide file tree
Showing 6 changed files with 202 additions and 1 deletion.
10 changes: 9 additions & 1 deletion benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ for T in [Float64]
f = LeastSquares($A, $b)
g = NormL1($lam)
end

SUITE[k]["AFBA-1"] = @benchmarkable solver(x0, y0, f=f, g=g, beta_f=beta_f) setup=begin
beta_f = opnorm($A)^2
solver = ProximalAlgorithms.AFBA(theta=$R(1), mu=$R(1), tol=$R(1e-6))
Expand All @@ -73,4 +73,12 @@ for T in [Float64]
h = Translate(SqrNormL2(), -$b)
g = NormL1($lam)
end

SUITE[k]["SFISTA"] = @benchmarkable solver(y0, f=f, Lf=Lf, h=h) setup=begin
solver = ProximalAlgorithms.SFISTA(tol=$R(1e-3))
y0 = zeros($T, size($A, 2))
f = LeastSquares($A, $b)
h = NormL1($lam)
Lf = opnorm($A)^2
end
end
1 change: 1 addition & 0 deletions src/ProximalAlgorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,6 @@ include("algorithms/drls.jl")
include("algorithms/primaldual.jl")
include("algorithms/davisyin.jl")
include("algorithms/lilin.jl")
include("algorithms/fista.jl")

end # module
129 changes: 129 additions & 0 deletions src/algorithms/fista.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# An implementation of a FISTA-like method, where the smooth part of the objective function can be strongly convex.

using Base.Iterators
using ProximalAlgorithms.IterationTools
using ProximalOperators: Zero
using LinearAlgebra
using Printf

"""
SFISTAIteration(; <keyword-arguments>)
Instantiate the FISTA-like method in [3] for solving strongly-convex composite optimization problems of the form
minimize f(x) + h(x),
where h is proper closed convex and f is a continuously differentiable function that is μ-strongly convex and whose gradient is
Lf-Lipschitz continuous.
The scheme is based on Nesterov's accelerated gradient method [1, Eq. (4.9)] and Beck's method for the convex case [2]. Its full
definition is given in [3, Algorithm 2.2.2.], and some analyses of this method are given in [3, 4, 5]. Another perspective is that
it is a special instance of [4, Algorithm 1] in which μh=0.
# Arguments
- `y0`: initial point; must be in the domain of h.
- `f=Zero()`: smooth objective term.
- `h=Zero()`: proximable objective term.
- `μf=0` : strong convexity constant of f (see above).
- `Lf` : Lipschitz constant of ∇f (see above).
- `adaptive=false` : enables the use of adaptive stepsize selection.
# References
- [1] Nesterov, Y. (2013). Gradient methods for minimizing composite functions. Mathematical Programming, 140(1), 125-161.
- [2] Beck, A., & Teboulle, M. (2009). A fast iterative shrinkage-thresholding algorithm for linear inverse problems. SIAM journal
on imaging sciences, 2(1), 183-202.
- [3] Kong, W. (2021). Accelerated Inexact First-Order Methods for Solving Nonconvex Composite Optimization Problems. arXiv
preprint arXiv:2104.09685.
- [4] Kong, W., Melo, J. G., & Monteiro, R. D. (2021). FISTA and Extensions - Review and New Insights. arXiv preprint
arXiv:2107.01267.
- [5] Florea, M. I. (2018). Constructing Accelerated Algorithms for Large-scale Optimization-Framework, Algorithms, and
Applications.
"""

@Base.kwdef struct SFISTAIteration{R,C<:Union{R,Complex{R}},Tx<:AbstractArray{C},Tf,Th}
y0::Tx
f::Tf = Zero()
h::Th = Zero()
Lf::R
μf::R = real(eltype(Lf))(0.0)
end

Base.IteratorSize(::Type{<:SFISTAIteration}) = Base.IsInfinite()

Base.@kwdef mutable struct SFISTAState{R, Tx}
λ::R # stepsize (mutable if iter.adaptive == true).
yPrev::Tx # previous main iterate.
y:: Tx = zero(yPrev) # main iterate.
xPrev::Tx = copy(yPrev) # previous auxiliary iterate.
x::Tx = zero(yPrev) # auxiliary iterate (see [3]).
xt::Tx = zero(yPrev) # prox center used to generate main iterate y.
τ::R = real(eltype(yPrev))(1.0) # helper variable (see [3]).
a::R = real(eltype(yPrev))(0.0) # helper variable (see [3]).
APrev::R = real(eltype(yPrev))(1.0) # previous A (helper variable).
A::R = real(eltype(yPrev))(0.0) # helper variable (see [3]).
gradf_xt::Tx = zero(yPrev) # array containing ∇f(xt).
end

function Base.iterate(iter::SFISTAIteration,
state::SFISTAState = SFISTAState=1/iter.Lf, yPrev=copy(iter.y0)))
# Set up helper variables.
state.τ = state.λ * (1 + iter.μf * state.APrev)
state.a = (state.τ + sqrt(state.τ ^ 2 + 4 * state.τ * state.APrev)) / 2
state.A = state.APrev + state.a
state.xt .= state.APrev / state.A * state.yPrev + state.a / state.A * state.xPrev
gradient!(state.gradf_xt, iter.f, state.xt)
λ2 = state.λ / (1 + state.λ * iter.μf)
# FISTA acceleration steps.
prox!(state.y, iter.h, state.xt - λ2 * state.gradf_xt, λ2)
state.x .= state.xPrev + state.a / (1 + state.A * iter.μf) * ((state.y - state.xt) / state.λ + iter.μf * (state.y - state.xPrev))
# Update state variables.
state.yPrev .= state.y
state.xPrev .= state.x
state.APrev = state.A
return state, state
end

## Solver.

struct SFISTA{R, K}
maxit::Int
tol::R
termination_type::String
verbose::Bool
freq::Int
kwargs::K
end

# Different stopping conditions (sc). Returns the current residual value and whether or not a stopping condition holds.
function check_sc(state::SFISTAState, iter::SFISTAIteration, tol, termination_type)
if termination_type == "AIPP"
# AIPP-style termination [4]. The main inclusion is: r ∈ ∂_η(f + h)(y).
r = (iter.y0 - state.x) / state.A
η = (norm(iter.y0 - state.y) ^ 2 - norm(state.x - state.y) ^ 2) / (2 * state.A)
res = (norm(r) ^ 2 + max(η, 0.0)) / max(norm(iter.y0 - state.y + r) ^ 2, 1e-16)
else
# Classic (approximate) first-order stationary point [4]. The main inclusion is: r ∈ ∇f(y) + ∂h(y).
λ2 = state.λ / (1 + state.λ * iter.μf)
gradf_y, = gradient(iter.f, state.y)
r = gradf_y - state.gradf_xt + (state.xt - state.y) / λ2
res = norm(r)
end
return res, (res <= tol || res tol)
end

# Functor ('function-like object') for the above type.
function (solver::SFISTA)(y0; kwargs...)
raw_iter = SFISTAIteration(; y0=y0, solver.kwargs..., kwargs...)
stop(state::SFISTAState) = check_sc(state, raw_iter, solver.tol, solver.termination_type)[2]
disp((it, state)) = @printf("%5d | %.3e\n", it, check_sc(state, raw_iter, solver.tol, solver.termination_type)[1])
iter = take(halt(raw_iter, stop), solver.maxit)
iter = enumerate(iter)
if solver.verbose
iter = tee(sample(iter, solver.freq), disp)
end
num_iters, state_final = loop(iter)
return state_final.y, num_iters
end

SFISTA(; maxit=1000, tol=1e-6, termination_type="", verbose=false, freq=(maxit < Inf ? Int(maxit/100) : 100), kwargs...) =
SFISTA(maxit, tol, termination_type, verbose, freq, kwargs)
15 changes: 15 additions & 0 deletions test/problems/test_lasso_small.jl
Original file line number Diff line number Diff line change
Expand Up @@ -194,4 +194,19 @@ using ProximalAlgorithms

end

@testset "SFISTA" begin

# SFISTA

x0 = zeros(T, n)
x0_backup = copy(x0)
solver = ProximalAlgorithms.SFISTA(tol = 10 * TOL)
y, it = solver(x0, f = f2, h = g, Lf = opnorm(A)^2)
@test eltype(y) == T
@test norm(y - x_star, Inf) <= 10 * TOL
@test it < 100
@test x0 == x0_backup

end

end
47 changes: 47 additions & 0 deletions test/problems/test_lasso_small_strongly_convex.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
using LinearAlgebra
using Test
using Random

using ProximalOperators
using ProximalAlgorithms

@testset "Lasso small (strongly convex, $T)" for T in [Float32, Float64]

Random.seed!(777)
dim = 5
μf = T(1)
Lf = T(10)

x_star = convert(Vector{T}, 1.5 * rand(T, dim) .- 0.5)

lam = (μf + Lf) / 2
@test typeof(lam) == T

D = Diagonal(sqrt(μf) .+ (sqrt(Lf) - sqrt(μf)) * rand(T, dim))
D[1] = sqrt(μf)
D[end] = sqrt(Lf)
Q = qr(rand(T, (dim, dim))).Q
A = Q * D * Q'
b = A * x_star + lam * inv(A') * sign.(x_star)

f = LeastSquares(A, b)
h = NormL1(lam)

TOL = T(1e-4)

@testset "SFISTA" begin

# SFISTA

x0 = A \ b
x0_backup = copy(x0)
solver = ProximalAlgorithms.SFISTA(tol = TOL)
y, it = solver(x0, f = f, h = h, Lf = Lf, μf = μf)
@test eltype(y) == T
@test norm(y - x_star) <= TOL
@test it < 200
@test x0 == x0_backup

end

end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ include("accel/noaccel.jl")
include("problems/test_equivalence.jl")
include("problems/test_elasticnet.jl")
include("problems/test_lasso_small.jl")
include("problems/test_lasso_small_strongly_convex.jl")
include("problems/test_lasso_small_v_split.jl")
include("problems/test_lasso_small_h_split.jl")
include("problems/test_linear_programs.jl")
Expand Down

0 comments on commit cf6f7e9

Please sign in to comment.