diff --git a/Project.toml b/Project.toml index ae5ec04..e8ec092 100644 --- a/Project.toml +++ b/Project.toml @@ -8,14 +8,14 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537" [compat] -julia = "1.0.0" ProximalOperators = "0.14" +julia = "1.0.0" [extras] +AbstractOperators = "d9c5613a-d543-52d8-9afd-8f241a8c3f1c" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" -AbstractOperators = "d9c5613a-d543-52d8-9afd-8f241a8c3f1c" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] test = ["Random", "Test", "RecursiveArrayTools", "AbstractOperators"] diff --git a/src/ProximalAlgorithms.jl b/src/ProximalAlgorithms.jl index 779251a..41e589c 100644 --- a/src/ProximalAlgorithms.jl +++ b/src/ProximalAlgorithms.jl @@ -17,6 +17,7 @@ include("accel/lbfgs.jl") include("accel/anderson.jl") include("accel/nesterov.jl") include("accel/broyden.jl") +include("accel/noaccel.jl") # algorithms diff --git a/src/accel/noaccel.jl b/src/accel/noaccel.jl new file mode 100644 index 0000000..4a311fc --- /dev/null +++ b/src/accel/noaccel.jl @@ -0,0 +1,14 @@ +struct Noaccel end + +function update!(_::Noaccel, _, _) end + +import Base: * + +function (*)(L::Noaccel, v) + w = similar(v) + mul!(w, L, v) +end + +import LinearAlgebra: mul! + +mul!(d::T, _::Noaccel, v::T) where {T} = copyto!(d, v) diff --git a/src/algorithms/drls.jl b/src/algorithms/drls.jl index 98a69ca..b87ad36 100644 --- a/src/algorithms/drls.jl +++ b/src/algorithms/drls.jl @@ -14,10 +14,10 @@ struct DRLS_iterable{R,C<:Union{R,Complex{R}},Tx<:AbstractArray{C},Tf,Tg,TH} f::Tf g::Tg x0::Tx - lambda::R gamma::R + lambda::R c::R - N::Int + max_backtracks::Int H::TH end @@ -85,7 +85,7 @@ function Base.iterate(iter::DRLS_iterable{R}, state::DRLS_state) where {R} state.tau = R(1) state.x .= state.x_d - for k = 1:iter.N + for k = 1:iter.max_backtracks state.f_u = prox!(state.u, iter.f, state.x, iter.gamma) state.w .= 2 .* state.u .- state.x state.g_v = prox!(state.v, iter.g, state.w, iter.gamma) @@ -117,7 +117,7 @@ struct DRLS{R} beta::R gamma::Maybe{R} lambda::R - N::Int + max_backtracks::Int memory::Int maxit::Int tol::R @@ -129,7 +129,7 @@ struct DRLS{R} beta::R = R(0.5), gamma::Maybe{R} = nothing, lambda::R = R(1), - N::Int = 20, + max_backtracks::Int = 20, memory::Int = 5, maxit::Int = 1000, tol::R = R(1e-8), @@ -140,12 +140,12 @@ struct DRLS{R} @assert 0 < beta < 1 @assert gamma === nothing || gamma > 0 @assert 0 < lambda < 2 - @assert N > 0 + @assert max_backtracks > 0 @assert memory >= 0 @assert maxit > 0 @assert tol > 0 @assert freq > 0 - new(alpha, beta, gamma, lambda, N, memory, maxit, tol, verbose, freq) + new(alpha, beta, gamma, lambda, max_backtracks, memory, maxit, tol, verbose, freq) end end @@ -184,7 +184,7 @@ function (solver::DRLS{R})( c = solver.beta * C_gamma_lambda iter = - DRLS_iterable(f, g, x0, solver.lambda, gamma, c, solver.N, LBFGS(x0, solver.memory)) + DRLS_iterable(f, g, x0, gamma, solver.lambda, c, solver.max_backtracks, LBFGS(x0, solver.memory)) iter = take(halt(iter, stop), solver.maxit) iter = enumerate(iter) if solver.verbose diff --git a/test/accel/noaccel.jl b/test/accel/noaccel.jl new file mode 100644 index 0000000..2148800 --- /dev/null +++ b/test/accel/noaccel.jl @@ -0,0 +1,17 @@ +using LinearAlgebra +using Test + +using ProximalAlgorithms + +@testset "Noaccel" begin + L = ProximalAlgorithms.Noaccel() + + x = randn(10) + y = L*x + + @test y == x + + mul!(y, L, x) + + @test y == x +end diff --git a/test/problems/test_drs_drls_equivalence.jl b/test/problems/test_drs_drls_equivalence.jl new file mode 100644 index 0000000..5dcbe8f --- /dev/null +++ b/test/problems/test_drs_drls_equivalence.jl @@ -0,0 +1,33 @@ +using LinearAlgebra +using Test + +using ProximalOperators +using ProximalAlgorithms + +@testset "DRS/DRLS equivalence ($T)" for T in [Float32, Float64] + A = T[ + 1.0 -2.0 3.0 -4.0 5.0 + 2.0 -1.0 0.0 -1.0 3.0 + -1.0 0.0 4.0 -3.0 2.0 + -1.0 -1.0 -1.0 1.0 3.0 + ] + b = T[1.0, 2.0, 3.0, 4.0] + + m, n = size(A) + + R = real(T) + + lam = R(0.1) * norm(A' * b, Inf) + + f = LeastSquares(A, b) + g = NormL1(lam) + + x0 = zeros(R, n) + + drs_iter = ProximalAlgorithms.DRS_iterable(f, g, x0, R(10) / opnorm(A)^2) + drls_iter = ProximalAlgorithms.DRLS_iterable(f, g, x0, R(10) / opnorm(A)^2, R(1), -R(Inf), 1, ProximalAlgorithms.Noaccel()) + + for (state_drs, state_drls) in Iterators.take(zip(drs_iter, drls_iter), 10) + @test isapprox(state_drs.x, state_drls.xbar) + end +end diff --git a/test/problems/test_lasso_small.jl b/test/problems/test_lasso_small.jl index d6361e0..1086ad9 100644 --- a/test/problems/test_lasso_small.jl +++ b/test/problems/test_lasso_small.jl @@ -1,8 +1,10 @@ -@testset "Lasso small ($T)" for T in [Float32, Float64, ComplexF32, ComplexF64] - using ProximalOperators - using ProximalAlgorithms - using LinearAlgebra +using LinearAlgebra +using Test + +using ProximalOperators +using ProximalAlgorithms +@testset "Lasso small ($T)" for T in [Float32, Float64, ComplexF32, ComplexF64] A = T[ 1.0 -2.0 3.0 -4.0 5.0 2.0 -1.0 0.0 -1.0 3.0 @@ -116,7 +118,7 @@ x0 = zeros(T, n) solver = - ProximalAlgorithms.DouglasRachford{R}(gamma = R(10.0) / opnorm(A)^2, tol = TOL) + ProximalAlgorithms.DouglasRachford{R}(gamma = R(10) / opnorm(A)^2, tol = TOL) y, z, it = solver(x0, f = f2, g = g) @test eltype(y) == T @test eltype(z) == T diff --git a/test/problems/test_verbose.jl b/test/problems/test_verbose.jl index 7b6c48d..423c7dc 100644 --- a/test/problems/test_verbose.jl +++ b/test/problems/test_verbose.jl @@ -125,7 +125,7 @@ x0 = zeros(T, n) solver = ProximalAlgorithms.DouglasRachford{R}( - gamma = R(10.0) / opnorm(A)^2, + gamma = R(10) / opnorm(A)^2, tol = TOL, verbose = true, ) diff --git a/test/runtests.jl b/test/runtests.jl index 65cdc84..e555cef 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,7 +9,9 @@ include("accel/lbfgs.jl") include("accel/anderson.jl") include("accel/nesterov.jl") include("accel/broyden.jl") +include("accel/noaccel.jl") +include("problems/test_drs_drls_equivalence.jl") include("problems/test_elasticnet.jl") include("problems/test_lasso_small.jl") include("problems/test_lasso_small_v_split.jl")