-
Notifications
You must be signed in to change notification settings - Fork 21
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
202 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
# An implementation of a FISTA-like method, where the smooth part of the objective function can be strongly convex. | ||
|
||
using Base.Iterators | ||
using ProximalAlgorithms.IterationTools | ||
using ProximalOperators: Zero | ||
using LinearAlgebra | ||
using Printf | ||
|
||
""" | ||
SFISTAIteration(; <keyword-arguments>) | ||
Instantiate the FISTA-like method in [3] for solving strongly-convex composite optimization problems of the form | ||
minimize f(x) + h(x), | ||
where h is proper closed convex and f is a continuously differentiable function that is μ-strongly convex and whose gradient is | ||
Lf-Lipschitz continuous. | ||
The scheme is based on Nesterov's accelerated gradient method [1, Eq. (4.9)] and Beck's method for the convex case [2]. Its full | ||
definition is given in [3, Algorithm 2.2.2.], and some analyses of this method are given in [3, 4, 5]. Another perspective is that | ||
it is a special instance of [4, Algorithm 1] in which μh=0. | ||
# Arguments | ||
- `y0`: initial point; must be in the domain of h. | ||
- `f=Zero()`: smooth objective term. | ||
- `h=Zero()`: proximable objective term. | ||
- `μf=0` : strong convexity constant of f (see above). | ||
- `Lf` : Lipschitz constant of ∇f (see above). | ||
- `adaptive=false` : enables the use of adaptive stepsize selection. | ||
# References | ||
- [1] Nesterov, Y. (2013). Gradient methods for minimizing composite functions. Mathematical Programming, 140(1), 125-161. | ||
- [2] Beck, A., & Teboulle, M. (2009). A fast iterative shrinkage-thresholding algorithm for linear inverse problems. SIAM journal | ||
on imaging sciences, 2(1), 183-202. | ||
- [3] Kong, W. (2021). Accelerated Inexact First-Order Methods for Solving Nonconvex Composite Optimization Problems. arXiv | ||
preprint arXiv:2104.09685. | ||
- [4] Kong, W., Melo, J. G., & Monteiro, R. D. (2021). FISTA and Extensions - Review and New Insights. arXiv preprint | ||
arXiv:2107.01267. | ||
- [5] Florea, M. I. (2018). Constructing Accelerated Algorithms for Large-scale Optimization-Framework, Algorithms, and | ||
Applications. | ||
""" | ||
|
||
@Base.kwdef struct SFISTAIteration{R,C<:Union{R,Complex{R}},Tx<:AbstractArray{C},Tf,Th} | ||
y0::Tx | ||
f::Tf = Zero() | ||
h::Th = Zero() | ||
Lf::R | ||
μf::R = real(eltype(Lf))(0.0) | ||
end | ||
|
||
Base.IteratorSize(::Type{<:SFISTAIteration}) = Base.IsInfinite() | ||
|
||
Base.@kwdef mutable struct SFISTAState{R, Tx} | ||
λ::R # stepsize (mutable if iter.adaptive == true). | ||
yPrev::Tx # previous main iterate. | ||
y:: Tx = zero(yPrev) # main iterate. | ||
xPrev::Tx = copy(yPrev) # previous auxiliary iterate. | ||
x::Tx = zero(yPrev) # auxiliary iterate (see [3]). | ||
xt::Tx = zero(yPrev) # prox center used to generate main iterate y. | ||
τ::R = real(eltype(yPrev))(1.0) # helper variable (see [3]). | ||
a::R = real(eltype(yPrev))(0.0) # helper variable (see [3]). | ||
APrev::R = real(eltype(yPrev))(1.0) # previous A (helper variable). | ||
A::R = real(eltype(yPrev))(0.0) # helper variable (see [3]). | ||
gradf_xt::Tx = zero(yPrev) # array containing ∇f(xt). | ||
end | ||
|
||
function Base.iterate(iter::SFISTAIteration, | ||
state::SFISTAState = SFISTAState(λ=1/iter.Lf, yPrev=copy(iter.y0))) | ||
# Set up helper variables. | ||
state.τ = state.λ * (1 + iter.μf * state.APrev) | ||
state.a = (state.τ + sqrt(state.τ ^ 2 + 4 * state.τ * state.APrev)) / 2 | ||
state.A = state.APrev + state.a | ||
state.xt .= state.APrev / state.A * state.yPrev + state.a / state.A * state.xPrev | ||
gradient!(state.gradf_xt, iter.f, state.xt) | ||
λ2 = state.λ / (1 + state.λ * iter.μf) | ||
# FISTA acceleration steps. | ||
prox!(state.y, iter.h, state.xt - λ2 * state.gradf_xt, λ2) | ||
state.x .= state.xPrev + state.a / (1 + state.A * iter.μf) * ((state.y - state.xt) / state.λ + iter.μf * (state.y - state.xPrev)) | ||
# Update state variables. | ||
state.yPrev .= state.y | ||
state.xPrev .= state.x | ||
state.APrev = state.A | ||
return state, state | ||
end | ||
|
||
## Solver. | ||
|
||
struct SFISTA{R, K} | ||
maxit::Int | ||
tol::R | ||
termination_type::String | ||
verbose::Bool | ||
freq::Int | ||
kwargs::K | ||
end | ||
|
||
# Different stopping conditions (sc). Returns the current residual value and whether or not a stopping condition holds. | ||
function check_sc(state::SFISTAState, iter::SFISTAIteration, tol, termination_type) | ||
if termination_type == "AIPP" | ||
# AIPP-style termination [4]. The main inclusion is: r ∈ ∂_η(f + h)(y). | ||
r = (iter.y0 - state.x) / state.A | ||
η = (norm(iter.y0 - state.y) ^ 2 - norm(state.x - state.y) ^ 2) / (2 * state.A) | ||
res = (norm(r) ^ 2 + max(η, 0.0)) / max(norm(iter.y0 - state.y + r) ^ 2, 1e-16) | ||
else | ||
# Classic (approximate) first-order stationary point [4]. The main inclusion is: r ∈ ∇f(y) + ∂h(y). | ||
λ2 = state.λ / (1 + state.λ * iter.μf) | ||
gradf_y, = gradient(iter.f, state.y) | ||
r = gradf_y - state.gradf_xt + (state.xt - state.y) / λ2 | ||
res = norm(r) | ||
end | ||
return res, (res <= tol || res ≈ tol) | ||
end | ||
|
||
# Functor ('function-like object') for the above type. | ||
function (solver::SFISTA)(y0; kwargs...) | ||
raw_iter = SFISTAIteration(; y0=y0, solver.kwargs..., kwargs...) | ||
stop(state::SFISTAState) = check_sc(state, raw_iter, solver.tol, solver.termination_type)[2] | ||
disp((it, state)) = @printf("%5d | %.3e\n", it, check_sc(state, raw_iter, solver.tol, solver.termination_type)[1]) | ||
iter = take(halt(raw_iter, stop), solver.maxit) | ||
iter = enumerate(iter) | ||
if solver.verbose | ||
iter = tee(sample(iter, solver.freq), disp) | ||
end | ||
num_iters, state_final = loop(iter) | ||
return state_final.y, num_iters | ||
end | ||
|
||
SFISTA(; maxit=1000, tol=1e-6, termination_type="", verbose=false, freq=(maxit < Inf ? Int(maxit/100) : 100), kwargs...) = | ||
SFISTA(maxit, tol, termination_type, verbose, freq, kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
using LinearAlgebra | ||
using Test | ||
using Random | ||
|
||
using ProximalOperators | ||
using ProximalAlgorithms | ||
|
||
@testset "Lasso small (strongly convex, $T)" for T in [Float32, Float64] | ||
|
||
Random.seed!(777) | ||
dim = 5 | ||
μf = T(1) | ||
Lf = T(10) | ||
|
||
x_star = convert(Vector{T}, 1.5 * rand(T, dim) .- 0.5) | ||
|
||
lam = (μf + Lf) / 2 | ||
@test typeof(lam) == T | ||
|
||
D = Diagonal(sqrt(μf) .+ (sqrt(Lf) - sqrt(μf)) * rand(T, dim)) | ||
D[1] = sqrt(μf) | ||
D[end] = sqrt(Lf) | ||
Q = qr(rand(T, (dim, dim))).Q | ||
A = Q * D * Q' | ||
b = A * x_star + lam * inv(A') * sign.(x_star) | ||
|
||
f = LeastSquares(A, b) | ||
h = NormL1(lam) | ||
|
||
TOL = T(1e-4) | ||
|
||
@testset "SFISTA" begin | ||
|
||
# SFISTA | ||
|
||
x0 = A \ b | ||
x0_backup = copy(x0) | ||
solver = ProximalAlgorithms.SFISTA(tol = TOL) | ||
y, it = solver(x0, f = f, h = h, Lf = Lf, μf = μf) | ||
@test eltype(y) == T | ||
@test norm(y - x_star) <= TOL | ||
@test it < 200 | ||
@test x0 == x0_backup | ||
|
||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters