Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add equivalence test DRS/DRLS #46

Merged
merged 1 commit into from
Mar 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537"

[compat]
julia = "1.0.0"
ProximalOperators = "0.14"
julia = "1.0.0"

[extras]
AbstractOperators = "d9c5613a-d543-52d8-9afd-8f241a8c3f1c"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
AbstractOperators = "d9c5613a-d543-52d8-9afd-8f241a8c3f1c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Random", "Test", "RecursiveArrayTools", "AbstractOperators"]
1 change: 1 addition & 0 deletions src/ProximalAlgorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ include("accel/lbfgs.jl")
include("accel/anderson.jl")
include("accel/nesterov.jl")
include("accel/broyden.jl")
include("accel/noaccel.jl")

# algorithms

Expand Down
14 changes: 14 additions & 0 deletions src/accel/noaccel.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
struct Noaccel end

function update!(_::Noaccel, _, _) end

import Base: *

function (*)(L::Noaccel, v)
w = similar(v)
mul!(w, L, v)
end

import LinearAlgebra: mul!

mul!(d::T, _::Noaccel, v::T) where {T} = copyto!(d, v)
16 changes: 8 additions & 8 deletions src/algorithms/drls.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ struct DRLS_iterable{R,C<:Union{R,Complex{R}},Tx<:AbstractArray{C},Tf,Tg,TH}
f::Tf
g::Tg
x0::Tx
lambda::R
gamma::R
lambda::R
c::R
N::Int
max_backtracks::Int
H::TH
end

Expand Down Expand Up @@ -85,7 +85,7 @@ function Base.iterate(iter::DRLS_iterable{R}, state::DRLS_state) where {R}
state.tau = R(1)
state.x .= state.x_d

for k = 1:iter.N
for k = 1:iter.max_backtracks
state.f_u = prox!(state.u, iter.f, state.x, iter.gamma)
state.w .= 2 .* state.u .- state.x
state.g_v = prox!(state.v, iter.g, state.w, iter.gamma)
Expand Down Expand Up @@ -117,7 +117,7 @@ struct DRLS{R}
beta::R
gamma::Maybe{R}
lambda::R
N::Int
max_backtracks::Int
memory::Int
maxit::Int
tol::R
Expand All @@ -129,7 +129,7 @@ struct DRLS{R}
beta::R = R(0.5),
gamma::Maybe{R} = nothing,
lambda::R = R(1),
N::Int = 20,
max_backtracks::Int = 20,
memory::Int = 5,
maxit::Int = 1000,
tol::R = R(1e-8),
Expand All @@ -140,12 +140,12 @@ struct DRLS{R}
@assert 0 < beta < 1
@assert gamma === nothing || gamma > 0
@assert 0 < lambda < 2
@assert N > 0
@assert max_backtracks > 0
@assert memory >= 0
@assert maxit > 0
@assert tol > 0
@assert freq > 0
new(alpha, beta, gamma, lambda, N, memory, maxit, tol, verbose, freq)
new(alpha, beta, gamma, lambda, max_backtracks, memory, maxit, tol, verbose, freq)
end
end

Expand Down Expand Up @@ -184,7 +184,7 @@ function (solver::DRLS{R})(
c = solver.beta * C_gamma_lambda

iter =
DRLS_iterable(f, g, x0, solver.lambda, gamma, c, solver.N, LBFGS(x0, solver.memory))
DRLS_iterable(f, g, x0, gamma, solver.lambda, c, solver.max_backtracks, LBFGS(x0, solver.memory))
iter = take(halt(iter, stop), solver.maxit)
iter = enumerate(iter)
if solver.verbose
Expand Down
17 changes: 17 additions & 0 deletions test/accel/noaccel.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
using LinearAlgebra
using Test

using ProximalAlgorithms

@testset "Noaccel" begin
L = ProximalAlgorithms.Noaccel()

x = randn(10)
y = L*x

@test y == x

mul!(y, L, x)

@test y == x
end
33 changes: 33 additions & 0 deletions test/problems/test_drs_drls_equivalence.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
using LinearAlgebra
using Test

using ProximalOperators
using ProximalAlgorithms

@testset "DRS/DRLS equivalence ($T)" for T in [Float32, Float64]
A = T[
1.0 -2.0 3.0 -4.0 5.0
2.0 -1.0 0.0 -1.0 3.0
-1.0 0.0 4.0 -3.0 2.0
-1.0 -1.0 -1.0 1.0 3.0
]
b = T[1.0, 2.0, 3.0, 4.0]

m, n = size(A)

R = real(T)

lam = R(0.1) * norm(A' * b, Inf)

f = LeastSquares(A, b)
g = NormL1(lam)

x0 = zeros(R, n)

drs_iter = ProximalAlgorithms.DRS_iterable(f, g, x0, R(10) / opnorm(A)^2)
drls_iter = ProximalAlgorithms.DRLS_iterable(f, g, x0, R(10) / opnorm(A)^2, R(1), -R(Inf), 1, ProximalAlgorithms.Noaccel())

for (state_drs, state_drls) in Iterators.take(zip(drs_iter, drls_iter), 10)
@test isapprox(state_drs.x, state_drls.xbar)
end
end
12 changes: 7 additions & 5 deletions test/problems/test_lasso_small.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
@testset "Lasso small ($T)" for T in [Float32, Float64, ComplexF32, ComplexF64]
using ProximalOperators
using ProximalAlgorithms
using LinearAlgebra
using LinearAlgebra
using Test

using ProximalOperators
using ProximalAlgorithms

@testset "Lasso small ($T)" for T in [Float32, Float64, ComplexF32, ComplexF64]
A = T[
1.0 -2.0 3.0 -4.0 5.0
2.0 -1.0 0.0 -1.0 3.0
Expand Down Expand Up @@ -116,7 +118,7 @@

x0 = zeros(T, n)
solver =
ProximalAlgorithms.DouglasRachford{R}(gamma = R(10.0) / opnorm(A)^2, tol = TOL)
ProximalAlgorithms.DouglasRachford{R}(gamma = R(10) / opnorm(A)^2, tol = TOL)
y, z, it = solver(x0, f = f2, g = g)
@test eltype(y) == T
@test eltype(z) == T
Expand Down
2 changes: 1 addition & 1 deletion test/problems/test_verbose.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@

x0 = zeros(T, n)
solver = ProximalAlgorithms.DouglasRachford{R}(
gamma = R(10.0) / opnorm(A)^2,
gamma = R(10) / opnorm(A)^2,
tol = TOL,
verbose = true,
)
Expand Down
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ include("accel/lbfgs.jl")
include("accel/anderson.jl")
include("accel/nesterov.jl")
include("accel/broyden.jl")
include("accel/noaccel.jl")

include("problems/test_drs_drls_equivalence.jl")
include("problems/test_elasticnet.jl")
include("problems/test_lasso_small.jl")
include("problems/test_lasso_small_v_split.jl")
Expand Down