diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index 27301c9..9ffd17b 100644 --- a/benchmark/benchmarks.jl +++ b/benchmark/benchmarks.jl @@ -55,7 +55,7 @@ for T in [Float64] f = LeastSquares($A, $b) g = NormL1($lam) end - + SUITE[k]["AFBA-1"] = @benchmarkable solver(x0, y0, f=f, g=g, beta_f=beta_f) setup=begin beta_f = opnorm($A)^2 solver = ProximalAlgorithms.AFBA(theta=$R(1), mu=$R(1), tol=$R(1e-6)) @@ -73,4 +73,12 @@ for T in [Float64] h = Translate(SqrNormL2(), -$b) g = NormL1($lam) end + + SUITE[k]["SFISTA"] = @benchmarkable solver(y0, f=f, Lf=Lf, h=h) setup=begin + solver = ProximalAlgorithms.SFISTA(tol=$R(1e-3)) + y0 = zeros($T, size($A, 2)) + f = LeastSquares($A, $b) + h = NormL1($lam) + Lf = opnorm($A)^2 + end end diff --git a/src/ProximalAlgorithms.jl b/src/ProximalAlgorithms.jl index 41e589c..63b1a5b 100644 --- a/src/ProximalAlgorithms.jl +++ b/src/ProximalAlgorithms.jl @@ -29,5 +29,6 @@ include("algorithms/drls.jl") include("algorithms/primaldual.jl") include("algorithms/davisyin.jl") include("algorithms/lilin.jl") +include("algorithms/fista.jl") end # module diff --git a/src/algorithms/fista.jl b/src/algorithms/fista.jl new file mode 100644 index 0000000..5d3250a --- /dev/null +++ b/src/algorithms/fista.jl @@ -0,0 +1,129 @@ +# An implementation of a FISTA-like method, where the smooth part of the objective function can be strongly convex. + +using Base.Iterators +using ProximalAlgorithms.IterationTools +using ProximalOperators: Zero +using LinearAlgebra +using Printf + +""" + SFISTAIteration(; ) + +Instantiate the FISTA-like method in [3] for solving strongly-convex composite optimization problems of the form + + minimize f(x) + h(x), + +where h is proper closed convex and f is a continuously differentiable function that is μ-strongly convex and whose gradient is +Lf-Lipschitz continuous. + +The scheme is based on Nesterov's accelerated gradient method [1, Eq. (4.9)] and Beck's method for the convex case [2]. Its full +definition is given in [3, Algorithm 2.2.2.], and some analyses of this method are given in [3, 4, 5]. Another perspective is that +it is a special instance of [4, Algorithm 1] in which μh=0. + +# Arguments +- `y0`: initial point; must be in the domain of h. +- `f=Zero()`: smooth objective term. +- `h=Zero()`: proximable objective term. +- `μf=0` : strong convexity constant of f (see above). +- `Lf` : Lipschitz constant of ∇f (see above). +- `adaptive=false` : enables the use of adaptive stepsize selection. + +# References +- [1] Nesterov, Y. (2013). Gradient methods for minimizing composite functions. Mathematical Programming, 140(1), 125-161. +- [2] Beck, A., & Teboulle, M. (2009). A fast iterative shrinkage-thresholding algorithm for linear inverse problems. SIAM journal +on imaging sciences, 2(1), 183-202. +- [3] Kong, W. (2021). Accelerated Inexact First-Order Methods for Solving Nonconvex Composite Optimization Problems. arXiv +preprint arXiv:2104.09685. +- [4] Kong, W., Melo, J. G., & Monteiro, R. D. (2021). FISTA and Extensions - Review and New Insights. arXiv preprint +arXiv:2107.01267. +- [5] Florea, M. I. (2018). Constructing Accelerated Algorithms for Large-scale Optimization-Framework, Algorithms, and +Applications. +""" + +@Base.kwdef struct SFISTAIteration{R,C<:Union{R,Complex{R}},Tx<:AbstractArray{C},Tf,Th} + y0::Tx + f::Tf = Zero() + h::Th = Zero() + Lf::R + μf::R = real(eltype(Lf))(0.0) +end + +Base.IteratorSize(::Type{<:SFISTAIteration}) = Base.IsInfinite() + +Base.@kwdef mutable struct SFISTAState{R, Tx} + λ::R # stepsize (mutable if iter.adaptive == true). + yPrev::Tx # previous main iterate. + y:: Tx = zero(yPrev) # main iterate. + xPrev::Tx = copy(yPrev) # previous auxiliary iterate. + x::Tx = zero(yPrev) # auxiliary iterate (see [3]). + xt::Tx = zero(yPrev) # prox center used to generate main iterate y. + τ::R = real(eltype(yPrev))(1.0) # helper variable (see [3]). + a::R = real(eltype(yPrev))(0.0) # helper variable (see [3]). + APrev::R = real(eltype(yPrev))(1.0) # previous A (helper variable). + A::R = real(eltype(yPrev))(0.0) # helper variable (see [3]). + gradf_xt::Tx = zero(yPrev) # array containing ∇f(xt). +end + +function Base.iterate(iter::SFISTAIteration, + state::SFISTAState = SFISTAState(λ=1/iter.Lf, yPrev=copy(iter.y0))) + # Set up helper variables. + state.τ = state.λ * (1 + iter.μf * state.APrev) + state.a = (state.τ + sqrt(state.τ ^ 2 + 4 * state.τ * state.APrev)) / 2 + state.A = state.APrev + state.a + state.xt .= state.APrev / state.A * state.yPrev + state.a / state.A * state.xPrev + gradient!(state.gradf_xt, iter.f, state.xt) + λ2 = state.λ / (1 + state.λ * iter.μf) + # FISTA acceleration steps. + prox!(state.y, iter.h, state.xt - λ2 * state.gradf_xt, λ2) + state.x .= state.xPrev + state.a / (1 + state.A * iter.μf) * ((state.y - state.xt) / state.λ + iter.μf * (state.y - state.xPrev)) + # Update state variables. + state.yPrev .= state.y + state.xPrev .= state.x + state.APrev = state.A + return state, state +end + +## Solver. + +struct SFISTA{R, K} + maxit::Int + tol::R + termination_type::String + verbose::Bool + freq::Int + kwargs::K +end + +# Different stopping conditions (sc). Returns the current residual value and whether or not a stopping condition holds. +function check_sc(state::SFISTAState, iter::SFISTAIteration, tol, termination_type) + if termination_type == "AIPP" + # AIPP-style termination [4]. The main inclusion is: r ∈ ∂_η(f + h)(y). + r = (iter.y0 - state.x) / state.A + η = (norm(iter.y0 - state.y) ^ 2 - norm(state.x - state.y) ^ 2) / (2 * state.A) + res = (norm(r) ^ 2 + max(η, 0.0)) / max(norm(iter.y0 - state.y + r) ^ 2, 1e-16) + else + # Classic (approximate) first-order stationary point [4]. The main inclusion is: r ∈ ∇f(y) + ∂h(y). + λ2 = state.λ / (1 + state.λ * iter.μf) + gradf_y, = gradient(iter.f, state.y) + r = gradf_y - state.gradf_xt + (state.xt - state.y) / λ2 + res = norm(r) + end + return res, (res <= tol || res ≈ tol) +end + +# Functor ('function-like object') for the above type. +function (solver::SFISTA)(y0; kwargs...) + raw_iter = SFISTAIteration(; y0=y0, solver.kwargs..., kwargs...) + stop(state::SFISTAState) = check_sc(state, raw_iter, solver.tol, solver.termination_type)[2] + disp((it, state)) = @printf("%5d | %.3e\n", it, check_sc(state, raw_iter, solver.tol, solver.termination_type)[1]) + iter = take(halt(raw_iter, stop), solver.maxit) + iter = enumerate(iter) + if solver.verbose + iter = tee(sample(iter, solver.freq), disp) + end + num_iters, state_final = loop(iter) + return state_final.y, num_iters +end + +SFISTA(; maxit=1000, tol=1e-6, termination_type="", verbose=false, freq=(maxit < Inf ? Int(maxit/100) : 100), kwargs...) = + SFISTA(maxit, tol, termination_type, verbose, freq, kwargs) diff --git a/test/problems/test_lasso_small.jl b/test/problems/test_lasso_small.jl index 1ce6c1f..9a63e1e 100644 --- a/test/problems/test_lasso_small.jl +++ b/test/problems/test_lasso_small.jl @@ -194,4 +194,19 @@ using ProximalAlgorithms end + @testset "SFISTA" begin + + # SFISTA + + x0 = zeros(T, n) + x0_backup = copy(x0) + solver = ProximalAlgorithms.SFISTA(tol = 10 * TOL) + y, it = solver(x0, f = f2, h = g, Lf = opnorm(A)^2) + @test eltype(y) == T + @test norm(y - x_star, Inf) <= 10 * TOL + @test it < 100 + @test x0 == x0_backup + + end + end diff --git a/test/problems/test_lasso_small_strongly_convex.jl b/test/problems/test_lasso_small_strongly_convex.jl new file mode 100644 index 0000000..68db52b --- /dev/null +++ b/test/problems/test_lasso_small_strongly_convex.jl @@ -0,0 +1,47 @@ +using LinearAlgebra +using Test +using Random + +using ProximalOperators +using ProximalAlgorithms + +@testset "Lasso small (strongly convex, $T)" for T in [Float32, Float64] + + Random.seed!(777) + dim = 5 + μf = T(1) + Lf = T(10) + + x_star = convert(Vector{T}, 1.5 * rand(T, dim) .- 0.5) + + lam = (μf + Lf) / 2 + @test typeof(lam) == T + + D = Diagonal(sqrt(μf) .+ (sqrt(Lf) - sqrt(μf)) * rand(T, dim)) + D[1] = sqrt(μf) + D[end] = sqrt(Lf) + Q = qr(rand(T, (dim, dim))).Q + A = Q * D * Q' + b = A * x_star + lam * inv(A') * sign.(x_star) + + f = LeastSquares(A, b) + h = NormL1(lam) + + TOL = T(1e-4) + + @testset "SFISTA" begin + + # SFISTA + + x0 = A \ b + x0_backup = copy(x0) + solver = ProximalAlgorithms.SFISTA(tol = TOL) + y, it = solver(x0, f = f, h = h, Lf = Lf, μf = μf) + @test eltype(y) == T + @test norm(y - x_star) <= TOL + @test it < 200 + @test x0 == x0_backup + + end + +end diff --git a/test/runtests.jl b/test/runtests.jl index 9335b15..d3ba29e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -14,6 +14,7 @@ include("accel/noaccel.jl") include("problems/test_equivalence.jl") include("problems/test_elasticnet.jl") include("problems/test_lasso_small.jl") +include("problems/test_lasso_small_strongly_convex.jl") include("problems/test_lasso_small_v_split.jl") include("problems/test_lasso_small_h_split.jl") include("problems/test_linear_programs.jl")