Skip to content

Commit

Permalink
Merge 73c25cd into 069de90
Browse files Browse the repository at this point in the history
  • Loading branch information
lostella authored Mar 13, 2021
2 parents 069de90 + 73c25cd commit 271e93d
Show file tree
Hide file tree
Showing 9 changed files with 86 additions and 17 deletions.
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537"

[compat]
julia = "1.0.0"
ProximalOperators = "0.14"
julia = "1.0.0"

[extras]
AbstractOperators = "d9c5613a-d543-52d8-9afd-8f241a8c3f1c"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
AbstractOperators = "d9c5613a-d543-52d8-9afd-8f241a8c3f1c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Random", "Test", "RecursiveArrayTools", "AbstractOperators"]
1 change: 1 addition & 0 deletions src/ProximalAlgorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ include("accel/lbfgs.jl")
include("accel/anderson.jl")
include("accel/nesterov.jl")
include("accel/broyden.jl")
include("accel/noaccel.jl")

# algorithms

Expand Down
14 changes: 14 additions & 0 deletions src/accel/noaccel.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
struct Noaccel end

function update!(_::Noaccel, _, _) end

import Base: *

function (*)(L::Noaccel, v)
w = similar(v)
mul!(w, L, v)
end

import LinearAlgebra: mul!

mul!(d::T, _::Noaccel, v::T) where {T} = copyto!(d, v)
16 changes: 8 additions & 8 deletions src/algorithms/drls.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ struct DRLS_iterable{R,C<:Union{R,Complex{R}},Tx<:AbstractArray{C},Tf,Tg,TH}
f::Tf
g::Tg
x0::Tx
lambda::R
gamma::R
lambda::R
c::R
N::Int
max_backtracks::Int
H::TH
end

Expand Down Expand Up @@ -85,7 +85,7 @@ function Base.iterate(iter::DRLS_iterable{R}, state::DRLS_state) where {R}
state.tau = R(1)
state.x .= state.x_d

for k = 1:iter.N
for k = 1:iter.max_backtracks
state.f_u = prox!(state.u, iter.f, state.x, iter.gamma)
state.w .= 2 .* state.u .- state.x
state.g_v = prox!(state.v, iter.g, state.w, iter.gamma)
Expand Down Expand Up @@ -117,7 +117,7 @@ struct DRLS{R}
beta::R
gamma::Maybe{R}
lambda::R
N::Int
max_backtracks::Int
memory::Int
maxit::Int
tol::R
Expand All @@ -129,7 +129,7 @@ struct DRLS{R}
beta::R = R(0.5),
gamma::Maybe{R} = nothing,
lambda::R = R(1),
N::Int = 20,
max_backtracks::Int = 20,
memory::Int = 5,
maxit::Int = 1000,
tol::R = R(1e-8),
Expand All @@ -140,12 +140,12 @@ struct DRLS{R}
@assert 0 < beta < 1
@assert gamma === nothing || gamma > 0
@assert 0 < lambda < 2
@assert N > 0
@assert max_backtracks > 0
@assert memory >= 0
@assert maxit > 0
@assert tol > 0
@assert freq > 0
new(alpha, beta, gamma, lambda, N, memory, maxit, tol, verbose, freq)
new(alpha, beta, gamma, lambda, max_backtracks, memory, maxit, tol, verbose, freq)
end
end

Expand Down Expand Up @@ -184,7 +184,7 @@ function (solver::DRLS{R})(
c = solver.beta * C_gamma_lambda

iter =
DRLS_iterable(f, g, x0, solver.lambda, gamma, c, solver.N, LBFGS(x0, solver.memory))
DRLS_iterable(f, g, x0, gamma, solver.lambda, c, solver.max_backtracks, LBFGS(x0, solver.memory))
iter = take(halt(iter, stop), solver.maxit)
iter = enumerate(iter)
if solver.verbose
Expand Down
17 changes: 17 additions & 0 deletions test/accel/noaccel.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
using LinearAlgebra
using Test

using ProximalAlgorithms

@testset "Noaccel" begin
L = ProximalAlgorithms.Noaccel()

x = randn(10)
y = L*x

@test y == x

mul!(y, L, x)

@test y == x
end
33 changes: 33 additions & 0 deletions test/problems/test_drs_drls_equivalence.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
using LinearAlgebra
using Test

using ProximalOperators
using ProximalAlgorithms

@testset "DRS/DRLS equivalence ($T)" for T in [Float32, Float64]
A = T[
1.0 -2.0 3.0 -4.0 5.0
2.0 -1.0 0.0 -1.0 3.0
-1.0 0.0 4.0 -3.0 2.0
-1.0 -1.0 -1.0 1.0 3.0
]
b = T[1.0, 2.0, 3.0, 4.0]

m, n = size(A)

R = real(T)

lam = R(0.1) * norm(A' * b, Inf)

f = LeastSquares(A, b)
g = NormL1(lam)

x0 = zeros(R, n)

drs_iter = ProximalAlgorithms.DRS_iterable(f, g, x0, R(10) / opnorm(A)^2)
drls_iter = ProximalAlgorithms.DRLS_iterable(f, g, x0, R(10) / opnorm(A)^2, R(1), -R(Inf), 1, ProximalAlgorithms.Noaccel())

for (state_drs, state_drls) in Iterators.take(zip(drs_iter, drls_iter), 10)
@test isapprox(state_drs.x, state_drls.xbar)
end
end
12 changes: 7 additions & 5 deletions test/problems/test_lasso_small.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
@testset "Lasso small ($T)" for T in [Float32, Float64, ComplexF32, ComplexF64]
using ProximalOperators
using ProximalAlgorithms
using LinearAlgebra
using LinearAlgebra
using Test

using ProximalOperators
using ProximalAlgorithms

@testset "Lasso small ($T)" for T in [Float32, Float64, ComplexF32, ComplexF64]
A = T[
1.0 -2.0 3.0 -4.0 5.0
2.0 -1.0 0.0 -1.0 3.0
Expand Down Expand Up @@ -116,7 +118,7 @@

x0 = zeros(T, n)
solver =
ProximalAlgorithms.DouglasRachford{R}(gamma = R(10.0) / opnorm(A)^2, tol = TOL)
ProximalAlgorithms.DouglasRachford{R}(gamma = R(10) / opnorm(A)^2, tol = TOL)
y, z, it = solver(x0, f = f2, g = g)
@test eltype(y) == T
@test eltype(z) == T
Expand Down
2 changes: 1 addition & 1 deletion test/problems/test_verbose.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@

x0 = zeros(T, n)
solver = ProximalAlgorithms.DouglasRachford{R}(
gamma = R(10.0) / opnorm(A)^2,
gamma = R(10) / opnorm(A)^2,
tol = TOL,
verbose = true,
)
Expand Down
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ include("accel/lbfgs.jl")
include("accel/anderson.jl")
include("accel/nesterov.jl")
include("accel/broyden.jl")
include("accel/noaccel.jl")

include("problems/test_drs_drls_equivalence.jl")
include("problems/test_elasticnet.jl")
include("problems/test_lasso_small.jl")
include("problems/test_lasso_small_v_split.jl")
Expand Down

0 comments on commit 271e93d

Please sign in to comment.