From 506015414dc4e4c3ad19c2c001fffda130de1504 Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Tue, 8 Feb 2022 21:42:17 +0100 Subject: [PATCH 1/6] tests pass --- Project.toml | 13 ++++++++- src/ProximalAlgorithms.jl | 23 ++-------------- src/algorithms/davis_yin.jl | 2 +- src/algorithms/douglas_rachford.jl | 2 +- src/algorithms/drls.jl | 35 +++++++++++-------------- src/algorithms/fast_forward_backward.jl | 2 +- src/algorithms/forward_backward.jl | 2 +- src/algorithms/li_lin.jl | 2 +- src/algorithms/panoc.jl | 6 ++--- src/algorithms/panocplus.jl | 2 +- src/algorithms/primal_dual.jl | 6 ++--- src/algorithms/sfista.jl | 2 +- src/algorithms/zerofpr.jl | 2 +- src/utilities/ad.jl | 12 +++------ src/utilities/conjugate.jl | 8 ------ test/Project.toml | 8 ------ test/definitions/arraypartition.jl | 30 +++++---------------- test/definitions/compose.jl | 13 +++++---- test/runtests.jl | 1 - test/utilities/test_ad.jl | 2 +- test/utilities/test_conjugate.jl | 30 --------------------- 21 files changed, 60 insertions(+), 143 deletions(-) delete mode 100644 src/utilities/conjugate.jl delete mode 100644 test/Project.toml delete mode 100644 test/utilities/test_conjugate.jl diff --git a/Project.toml b/Project.toml index 46bcc01..74d48a6 100644 --- a/Project.toml +++ b/Project.toml @@ -3,12 +3,23 @@ uuid = "140ffc9f-1907-541a-a177-7475e0a401e9" version = "0.5.0" [deps] +AbstractOperators = "d9c5613a-d543-52d8-9afd-8f241a8c3f1c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" +ProximalCore = "dc4f5ac2-75d1-4f31-931e-60435d74994b" ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -ProximalOperators = "0.14" +ProximalCore = "0.1" Zygote = "0.6" julia = "1.2" + +[extras] +ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[targets] +test = ["Random", "Test", "ProximalOperators", "RecursiveArrayTools"] diff --git a/src/ProximalAlgorithms.jl b/src/ProximalAlgorithms.jl index 40851fb..743c0ea 100644 --- a/src/ProximalAlgorithms.jl +++ b/src/ProximalAlgorithms.jl @@ -1,33 +1,14 @@ module ProximalAlgorithms -using ProximalOperators -import ProximalOperators: prox!, gradient! +using ProximalCore +using ProximalCore: prox, prox!, gradient, gradient! const RealOrComplex{R} = Union{R,Complex{R}} const Maybe{T} = Union{T,Nothing} -""" - prox!(y, f, x, gamma) - -Compute the proximal mapping of `f` at `x`, with stepsize `gamma`, and store the result in `y`. -Return the value of `f` at `y`. -""" -prox!(y, f, x, gamma) - -""" - gradient!(g, f, x) - -Compute the gradient of `f` at `x`, and stores it in `y`. Return the value of `f` at `x`. -""" -gradient!(y, f, x) - -# TODO move out -ProximalOperators.is_quadratic(::Any) = false - # various utilities include("utilities/ad.jl") -include("utilities/conjugate.jl") include("utilities/fb_tools.jl") include("utilities/iteration_tools.jl") diff --git a/src/algorithms/davis_yin.jl b/src/algorithms/davis_yin.jl index 7726ce6..93ad084 100644 --- a/src/algorithms/davis_yin.jl +++ b/src/algorithms/davis_yin.jl @@ -3,7 +3,7 @@ # pp. 829–858 (2017). using Printf -using ProximalOperators: Zero +using ProximalCore: Zero using LinearAlgebra using Printf diff --git a/src/algorithms/douglas_rachford.jl b/src/algorithms/douglas_rachford.jl index 9edce5a..06ae042 100644 --- a/src/algorithms/douglas_rachford.jl +++ b/src/algorithms/douglas_rachford.jl @@ -3,7 +3,7 @@ # Mathematical Programming, vol. 55, no. 1, pp. 293-318 (1989). using Base.Iterators -using ProximalOperators: Zero +using ProximalCore: Zero using LinearAlgebra using Printf diff --git a/src/algorithms/drls.jl b/src/algorithms/drls.jl index 580aa06..5f6097f 100644 --- a/src/algorithms/drls.jl +++ b/src/algorithms/drls.jl @@ -6,24 +6,20 @@ using Base.Iterators using ProximalAlgorithms.IterationTools -using ProximalOperators: Zero +using ProximalCore: Zero using LinearAlgebra using Printf -function drls_default_gamma(f, mf, Lf, alpha, lambda) +function drls_default_gamma(f::Tf, mf, Lf, alpha, lambda) where Tf if mf !== nothing && mf > 0 return 1 / (alpha * mf) end - if ProximalOperators.is_convex(f) - return alpha / Lf - else - return alpha * (2 - lambda) / (2 * Lf) - end + return ProximalCore.is_convex(Tf) ? alpha / Lf : alpha * (2 - lambda) / (2 * Lf) end -function drls_C(f, mf, Lf, gamma, lambda) +function drls_C(f::Tf, mf, Lf, gamma, lambda) where Tf a = mf === nothing || mf <= 0 ? gamma * Lf : 1 / (gamma * mf) - m = ProximalOperators.is_convex(f) ? max(a - lambda / 2, 0) : 1 + m = ProximalCore.is_convex(Tf) ? max(a - lambda / 2, 0) : 1 return (lambda / ((1 + a)^2) * ((2 - lambda) / 2 - a * m)) end @@ -121,7 +117,7 @@ update_direction_state!(::NesterovStyle, ::DRLSIteration, state::DRLSState) = re update_direction_state!(::NoAccelerationStyle, ::DRLSIteration, state::DRLSState) = return update_direction_state!(iter::DRLSIteration, state::DRLSState) = update_direction_state!(acceleration_style(typeof(iter.directions)), iter, state) -function Base.iterate(iter::DRLSIteration{R}, state::DRLSState) where R +function Base.iterate(iter::DRLSIteration{R, Tx, Tf}, state::DRLSState) where {R, Tx, Tf} DRE_curr = DRE(state) set_next_direction!(iter, state) @@ -150,16 +146,15 @@ function Base.iterate(iter::DRLSIteration{R}, state::DRLSState) where R state.tau = k == iter.max_backtracks ? R(0) : state.tau / 2 state.x .= state.tau .* state.x_d .+ (1 - state.tau) .* state.xbar_prev - if k == 1 && ProximalOperators.is_generalized_quadratic(iter.f) - copyto!(state.u1, state.u) - c = prox!(state.u0, iter.f, state.xbar_prev, iter.gamma) - state.temp_x1 .= state.xbar_prev .- state.x_d - state.temp_x2 .= state.xbar_prev .- state.u0 - b = real(dot(state.temp_x1, state.temp_x2)) / iter.gamma - a = state.f_u - b - c - end - - if ProximalOperators.is_generalized_quadratic(iter.f) + if ProximalCore.is_generalized_quadratic(Tf) + if k == 1 + copyto!(state.u1, state.u) + c = prox!(state.u0, iter.f, state.xbar_prev, iter.gamma) + state.temp_x1 .= state.xbar_prev .- state.x_d + state.temp_x2 .= state.xbar_prev .- state.u0 + b = real(dot(state.temp_x1, state.temp_x2)) / iter.gamma + a = state.f_u - b - c + end state.u .= state.tau .* state.u1 .+ (1 - state.tau) .* state.u0 state.f_u = a * state.tau ^ 2 + b * state.tau + c else diff --git a/src/algorithms/fast_forward_backward.jl b/src/algorithms/fast_forward_backward.jl index 1120f11..b797478 100644 --- a/src/algorithms/fast_forward_backward.jl +++ b/src/algorithms/fast_forward_backward.jl @@ -7,7 +7,7 @@ using Base.Iterators using ProximalAlgorithms.IterationTools -using ProximalOperators: Zero +using ProximalCore: Zero using LinearAlgebra using Printf diff --git a/src/algorithms/forward_backward.jl b/src/algorithms/forward_backward.jl index 5d4ad6f..6d7f43b 100644 --- a/src/algorithms/forward_backward.jl +++ b/src/algorithms/forward_backward.jl @@ -3,7 +3,7 @@ using Base.Iterators using ProximalAlgorithms.IterationTools -using ProximalOperators: Zero +using ProximalCore: Zero using LinearAlgebra using Printf diff --git a/src/algorithms/li_lin.jl b/src/algorithms/li_lin.jl index ea2e27c..cb572a8 100644 --- a/src/algorithms/li_lin.jl +++ b/src/algorithms/li_lin.jl @@ -3,7 +3,7 @@ using Base.Iterators using ProximalAlgorithms.IterationTools -using ProximalOperators: Zero +using ProximalCore: Zero using LinearAlgebra using Printf diff --git a/src/algorithms/panoc.jl b/src/algorithms/panoc.jl index 8844dd4..6765f46 100644 --- a/src/algorithms/panoc.jl +++ b/src/algorithms/panoc.jl @@ -4,7 +4,7 @@ using Base.Iterators using ProximalAlgorithms.IterationTools -using ProximalOperators: Zero +using ProximalCore: Zero using LinearAlgebra using Printf @@ -110,7 +110,7 @@ reset_direction_state!(::QuasiNewtonStyle, ::PANOCIteration, state::PANOCState) reset_direction_state!(::NoAccelerationStyle, ::PANOCIteration, state::PANOCState) = return reset_direction_state!(iter::PANOCIteration, state::PANOCState) = reset_direction_state!(acceleration_style(typeof(iter.directions)), iter, state) -function Base.iterate(iter::PANOCIteration{R}, state::PANOCState) where R +function Base.iterate(iter::PANOCIteration{R, Tx, Tf}, state::PANOCState) where {R, Tx, Tf} f_Az, a, b, c = R(Inf), R(Inf), R(Inf), R(Inf) f_Az_upp = if iter.adaptive == true @@ -177,7 +177,7 @@ function Base.iterate(iter::PANOCIteration{R}, state::PANOCState) where R state.x .= state.tau .* state.x_d .+ (1 - state.tau) .* state.z_curr state.Ax .= state.tau .* state.Ax_d .+ (1 - state.tau) .* state.Az - if ProximalOperators.is_quadratic(iter.f) + if ProximalCore.is_generalized_quadratic(Tf) # in case f is quadratic, we can compute its value and gradient # along a line using interpolation and linear combinations # this allows saving operations diff --git a/src/algorithms/panocplus.jl b/src/algorithms/panocplus.jl index 155bf3d..63b2acc 100644 --- a/src/algorithms/panocplus.jl +++ b/src/algorithms/panocplus.jl @@ -4,7 +4,7 @@ using Base.Iterators using ProximalAlgorithms.IterationTools -using ProximalOperators: Zero +using ProximalCore: Zero using LinearAlgebra using Printf diff --git a/src/algorithms/primal_dual.jl b/src/algorithms/primal_dual.jl index cff6a13..1ac83ba 100644 --- a/src/algorithms/primal_dual.jl +++ b/src/algorithms/primal_dual.jl @@ -22,7 +22,7 @@ using Base.Iterators using ProximalAlgorithms.IterationTools -using ProximalOperators: Zero +using ProximalCore: Zero, IndZero, convex_conjugate using LinearAlgebra using Printf @@ -175,13 +175,13 @@ function Base.iterate(iter::AFBAIteration, state::AFBAState = AFBAState(x=copy(i prox!(state.xbar, iter.g, state.temp_x, iter.gamma[1]) # perform ybar-update step - gradient!(state.gradl, Conjugate(iter.l), state.y) + gradient!(state.gradl, convex_conjugate(iter.l), state.y) state.temp_x .= iter.theta .* state.xbar .+ (1 - iter.theta) .* state.x mul!(state.temp_y, iter.L, state.temp_x) state.temp_y .-= state.gradl state.temp_y .*= iter.gamma[2] state.temp_y .+= state.y - prox!(state.ybar, Conjugate(iter.h), state.temp_y, iter.gamma[2]) + prox!(state.ybar, convex_conjugate(iter.h), state.temp_y, iter.gamma[2]) # the residues state.FPR_x .= state.xbar .- state.x diff --git a/src/algorithms/sfista.jl b/src/algorithms/sfista.jl index 9ce43f0..954309d 100644 --- a/src/algorithms/sfista.jl +++ b/src/algorithms/sfista.jl @@ -2,7 +2,7 @@ using Base.Iterators using ProximalAlgorithms.IterationTools -using ProximalOperators: Zero +using ProximalCore: Zero using LinearAlgebra using Printf diff --git a/src/algorithms/zerofpr.jl b/src/algorithms/zerofpr.jl index 4c98750..148a582 100644 --- a/src/algorithms/zerofpr.jl +++ b/src/algorithms/zerofpr.jl @@ -5,7 +5,7 @@ using Base.Iterators using ProximalAlgorithms.IterationTools -using ProximalOperators: Zero +using ProximalCore: Zero using LinearAlgebra using Printf diff --git a/src/utilities/ad.jl b/src/utilities/ad.jl index a25d1a8..f14c097 100644 --- a/src/utilities/ad.jl +++ b/src/utilities/ad.jl @@ -1,14 +1,8 @@ using Zygote: pullback -using ProximalOperators +using ProximalCore -function ProximalOperators.gradient(f, x) +function ProximalCore.gradient!(grad, f, x) fx, pb = pullback(f, x) - grad = pb(one(fx))[1] - return grad, fx -end - -function ProximalOperators.gradient!(grad, f, x) - y, fx = gradient(f, x) - grad .= y + grad .= pb(one(fx))[1] return fx end diff --git a/src/utilities/conjugate.jl b/src/utilities/conjugate.jl deleted file mode 100644 index d831d8a..0000000 --- a/src/utilities/conjugate.jl +++ /dev/null @@ -1,8 +0,0 @@ -using ProximalOperators -using ProximalOperators: Zero - -ProximalOperators.Conjugate(_::Zero) = IndZero() -ProximalOperators.Conjugate(_::IndZero) = Zero() -ProximalOperators.Conjugate(f::SqrNormL2) = SqrNormL2(1.0 / f.lambda) - -# TODO: Add other useful functions and calculus rules such as translation diff --git a/test/Project.toml b/test/Project.toml deleted file mode 100644 index f4cf9bb..0000000 --- a/test/Project.toml +++ /dev/null @@ -1,8 +0,0 @@ -[deps] -AbstractOperators = "d9c5613a-d543-52d8-9afd-8f241a8c3f1c" -LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" -ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/definitions/arraypartition.jl b/test/definitions/arraypartition.jl index 3602677..87dcb78 100644 --- a/test/definitions/arraypartition.jl +++ b/test/definitions/arraypartition.jl @@ -1,36 +1,20 @@ -import ProximalOperators +import ProximalCore import RecursiveArrayTools -@inline function ProximalOperators.prox( - h::ProximalOperators.ProximableFunction, - x::RecursiveArrayTools.ArrayPartition, - gamma..., -) +@inline function ProximalCore.prox(h, x::RecursiveArrayTools.ArrayPartition, gamma...) # unwrap - y, fy = ProximalOperators.prox(h, x.x, gamma...) + y, fy = ProximalCore.prox(h, x.x, gamma...) # wrap return RecursiveArrayTools.ArrayPartition(y), fy end -@inline function ProximalOperators.gradient( - h::ProximalOperators.ProximableFunction, - x::RecursiveArrayTools.ArrayPartition, -) +@inline function ProximalCore.gradient(h, x::RecursiveArrayTools.ArrayPartition) # unwrap - grad, fx = ProximalOperators.gradient(h, x.x) + grad, fx = ProximalCore.gradient(h, x.x) # wrap return RecursiveArrayTools.ArrayPartition(grad), fx end -@inline ProximalOperators.prox!( - y::RecursiveArrayTools.ArrayPartition, - h::ProximalOperators.ProximableFunction, - x::RecursiveArrayTools.ArrayPartition, - gamma..., -) = ProximalOperators.prox!(y.x, h, x.x, gamma...) +@inline ProximalCore.prox!(y::RecursiveArrayTools.ArrayPartition, h, x::RecursiveArrayTools.ArrayPartition, gamma...) = ProximalCore.prox!(y.x, h, x.x, gamma...) -@inline ProximalOperators.gradient!( - y::RecursiveArrayTools.ArrayPartition, - h::ProximalOperators.ProximableFunction, - x::RecursiveArrayTools.ArrayPartition, -) = ProximalOperators.gradient!(y.x, h, x.x) +@inline ProximalCore.gradient!(y::RecursiveArrayTools.ArrayPartition, h, x::RecursiveArrayTools.ArrayPartition) = ProximalCore.gradient!(y.x, h, x.x) diff --git a/test/definitions/compose.jl b/test/definitions/compose.jl index ee50b03..8e38c2a 100644 --- a/test/definitions/compose.jl +++ b/test/definitions/compose.jl @@ -1,8 +1,7 @@ -using ProximalOperators: ProximableFunction using RecursiveArrayTools: ArrayPartition -import ProximalOperators: gradient!, gradient +using ProximalCore -struct ComposeAffine <: ProximableFunction +struct ComposeAffine f A b @@ -20,11 +19,11 @@ function compose_affine_gradient!(y, g::ComposeAffine, x) return v end -gradient!(y, g::ComposeAffine, x) = compose_affine_gradient!(y, g, x) -gradient!(y::ArrayPartition, g::ComposeAffine, x::ArrayPartition) = compose_affine_gradient!(y, g, x) +ProximalCore.gradient!(y, g::ComposeAffine, x) = compose_affine_gradient!(y, g, x) +ProximalCore.gradient!(y::ArrayPartition, g::ComposeAffine, x::ArrayPartition) = compose_affine_gradient!(y, g, x) -function ProximalOperators.gradient(h::ComposeAffine, x::ArrayPartition) +function ProximalCore.gradient(h::ComposeAffine, x::ArrayPartition) grad_fx = similar(x) - fx = gradient!(grad_fx, h, x) + fx = ProximalCore.gradient!(grad_fx, h, x) return grad_fx, fx end diff --git a/test/runtests.jl b/test/runtests.jl index df8c672..82460bc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,7 +5,6 @@ include("definitions/compose.jl") include("utilities/test_ad.jl") include("utilities/test_iteration_tools.jl") -include("utilities/test_conjugate.jl") include("utilities/test_fb_tools.jl") include("accel/test_lbfgs.jl") diff --git a/test/utilities/test_ad.jl b/test/utilities/test_ad.jl index d0d77c9..8abb327 100644 --- a/test/utilities/test_ad.jl +++ b/test/utilities/test_ad.jl @@ -18,7 +18,7 @@ using ProximalAlgorithms @testset "Gradient" begin x = randn(T, n) - gradfx, fx = gradient(f, x) + gradfx, fx = ProximalAlgorithms.gradient(f, x) @test eltype(gradfx) == T @test typeof(fx) == R @test gradfx ≈ A' * (A * x - b) diff --git a/test/utilities/test_conjugate.jl b/test/utilities/test_conjugate.jl deleted file mode 100644 index 1bce562..0000000 --- a/test/utilities/test_conjugate.jl +++ /dev/null @@ -1,30 +0,0 @@ -using Test - -@testset "Conjugate" begin - - using ProximalOperators - using ProximalOperators: Zero - using ProximalAlgorithms - - x = [1.0, -2.0, 3.0, -4.0, 5.0, -6.0] - - f = Conjugate(IndZero()) # = IndFree - grad_f_x, f_x = gradient(f, x) - @test iszero(grad_f_x) - @test iszero(f_x) - - g = Conjugate(Zero()) # = IndZero - prox_g_x, g_y = prox(g, x) - @test iszero(prox_g_x) - @test iszero(g_y) - - l = Conjugate(SqrNormL2()) - grad_l_x, l_x = gradient(l, x) - @test isequal(grad_l_x, x) - - lam = 2 - l = Conjugate(SqrNormL2(lam)) - grad_l_x, l_x = gradient(l, x) - @test isequal(grad_l_x, x / lam) - -end From 7323c7843f6ffbb838b243118008ffd32d04e983 Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Sat, 12 Feb 2022 22:25:17 +0100 Subject: [PATCH 2/6] add test env --- Project.toml | 11 ----------- test/Project.toml | 5 +++++ 2 files changed, 5 insertions(+), 11 deletions(-) create mode 100644 test/Project.toml diff --git a/Project.toml b/Project.toml index 74d48a6..22aea25 100644 --- a/Project.toml +++ b/Project.toml @@ -3,23 +3,12 @@ uuid = "140ffc9f-1907-541a-a177-7475e0a401e9" version = "0.5.0" [deps] -AbstractOperators = "d9c5613a-d543-52d8-9afd-8f241a8c3f1c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" ProximalCore = "dc4f5ac2-75d1-4f31-931e-60435d74994b" -ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ProximalCore = "0.1" Zygote = "0.6" julia = "1.2" - -[extras] -ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[targets] -test = ["Random", "Test", "ProximalOperators", "RecursiveArrayTools"] diff --git a/test/Project.toml b/test/Project.toml new file mode 100644 index 0000000..562f119 --- /dev/null +++ b/test/Project.toml @@ -0,0 +1,5 @@ +[deps] +ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" From c5cc064aec904829fe2befa1581ca26a2a9ebf3e Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Tue, 15 Feb 2022 23:18:27 +0100 Subject: [PATCH 3/6] update test env --- test/Project.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/Project.toml b/test/Project.toml index 562f119..27ac831 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,8 @@ [deps] +AbstractOperators = "d9c5613a-d543-52d8-9afd-8f241a8c3f1c" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" +ProximalCore = "dc4f5ac2-75d1-4f31-931e-60435d74994b" ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" From 5e8b89eda2a0929785684cee740bf40fb589d33a Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Fri, 18 Feb 2022 22:59:53 +0100 Subject: [PATCH 4/6] fix docs --- docs/Project.toml | 1 + docs/make.jl | 5 +++-- docs/src/guide/custom_objectives.jl | 34 +++++++++++++++-------------- 3 files changed, 22 insertions(+), 18 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index dee2318..a7ed7e2 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -5,4 +5,5 @@ HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" ProximalAlgorithms = "140ffc9f-1907-541a-a177-7475e0a401e9" +ProximalCore = "dc4f5ac2-75d1-4f31-931e-60435d74994b" ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537" diff --git a/docs/make.jl b/docs/make.jl index 86c94ed..ee1ad11 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,4 +1,5 @@ -using Documenter, DocumenterCitations, ProximalAlgorithms +using Documenter, DocumenterCitations +using ProximalAlgorithms, ProximalCore using Literate bib = CitationBibliography(joinpath(@__DIR__, "references.bib")) @@ -22,7 +23,7 @@ end makedocs( bib, - modules=[ProximalAlgorithms], + modules=[ProximalAlgorithms, ProximalCore], sitename="ProximalAlgorithms.jl", pages=[ "Home" => "index.md", diff --git a/docs/src/guide/custom_objectives.jl b/docs/src/guide/custom_objectives.jl index 0b4c361..09c83e0 100644 --- a/docs/src/guide/custom_objectives.jl +++ b/docs/src/guide/custom_objectives.jl @@ -1,25 +1,27 @@ # # [Custom objective terms](@id custom_terms) # -# ProximalAlgorithms relies on the first-order primitives implemented in [ProximalOperators](https://github.com/JuliaFirstOrder/ProximalOperators.jl): -# while a rich library of function types is provided there, one may need to formulate -# problems using custom objective terms. +# ProximalAlgorithms relies on the first-order primitives defined in [ProximalCore](https://github.com/JuliaFirstOrder/ProximalCore.jl). +# While a rich library of function types, implementing such primitives, is provided by [ProximalOperators](https://github.com/JuliaFirstOrder/ProximalOperators.jl), +# one may need to formulate problems using custom objective terms. # When that is the case, one only needs to implement the right first-order primitive, # ``\nabla f`` or ``\operatorname{prox}_{\gamma f}`` or both, for algorithms to be able # to work with ``f``. # -# Defining the proximal mapping for a custom function type requires adding a method for [`prox!`](@ref ProximalAlgorithms.prox!). +# Defining the proximal mapping for a custom function type requires adding a method for [`ProximalCore.prox!`](@ref). # -# To compute gradients, ProximalAlgorithms provides a fallback definition for [`gradient!`](@ref ProximalAlgorithms.gradient!), +# To compute gradients, ProximalAlgorithms provides a fallback definition for [`ProximalCore.gradient!`](@ref), # relying on [Zygote](https://github.com/FluxML/Zygote.jl) to use automatic differentiation. # Therefore, you can provide any (differentiable) Julia function wherever gradients need to be taken, # and everything will work out of the box. # # If however one would like to provide their own gradient implementation (e.g. for efficiency reasons), -# they can simply implement a method for [`gradient!`](@ref ProximalAlgorithms.gradient!). +# they can simply implement a method for [`ProximalCore.gradient!`](@ref). # # ```@docs -# ProximalAlgorithms.prox!(y, f, x, gamma) -# ProximalAlgorithms.gradient!(g, f, x) +# ProximalCore.prox +# ProximalCore.prox! +# ProximalCore.gradient +# ProximalCore.gradient! # ``` # # ## Example: constrained Rosenbrock @@ -33,13 +35,13 @@ rosenbrock2D(x) = 100 * (x[2] - x[1]^2)^2 + (1 - x[1])^2 # outside of the set. using LinearAlgebra -using ProximalOperators +using ProximalCore -struct IndUnitBall <: ProximalOperators.ProximableFunction end +struct IndUnitBall end (::IndUnitBall)(x) = norm(x) > 1 ? eltype(x)(Inf) : eltype(x)(0) -function ProximalOperators.prox!(y, ::IndUnitBall, x, gamma) +function ProximalCore.prox!(y, ::IndUnitBall, x, gamma) if norm(x) > 1 y .= x ./ norm(x) else @@ -73,7 +75,7 @@ scatter!([solution[1]], [solution[2]], color=:red, markershape=:star5, label="co # # We can achieve this by wrapping functions in a dedicated `Counting` type: -mutable struct Counting{T} <: ProximalOperators.ProximableFunction +mutable struct Counting{T} f::T gradient_count::Int prox_count::Int @@ -83,14 +85,14 @@ Counting(f::T) where T = Counting{T}(f, 0, 0) # Now we only need to intercept any call to `gradient!` and `prox!` and increase counters there: -function ProximalOperators.gradient!(y, f::Counting, x) +function ProximalCore.gradient!(y, f::Counting, x) f.gradient_count += 1 - return ProximalOperators.gradient!(y, f.f, x) + return ProximalCore.gradient!(y, f.f, x) end -function ProximalOperators.prox!(y, f::Counting, x, gamma) +function ProximalCore.prox!(y, f::Counting, x, gamma) f.prox_count += 1 - return ProximalOperators.prox!(y, f.f, x, gamma) + return ProximalCore.prox!(y, f.f, x, gamma) end # We can run again the previous example, this time wrapping the objective terms within `Counting`: From ea11a546eb56a5fd365f36a6e567a7cdf1d62e6c Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Fri, 18 Feb 2022 23:22:43 +0100 Subject: [PATCH 5/6] update benchmarks env --- benchmark/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmark/Project.toml b/benchmark/Project.toml index 278cee3..c7c039a 100644 --- a/benchmark/Project.toml +++ b/benchmark/Project.toml @@ -8,4 +8,4 @@ ProximalAlgorithms = "140ffc9f-1907-541a-a177-7475e0a401e9" ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537" [compat] -ProximalOperators = "0.14" +ProximalOperators = "0.15" From 1ab97dc08dd1a0202a2815cb49e2e24c0e8584e8 Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Sat, 19 Feb 2022 08:52:52 +0100 Subject: [PATCH 6/6] minor fix --- docs/Project.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/Project.toml b/docs/Project.toml index a7ed7e2..0fbc02e 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -7,3 +7,6 @@ Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" ProximalAlgorithms = "140ffc9f-1907-541a-a177-7475e0a401e9" ProximalCore = "dc4f5ac2-75d1-4f31-931e-60435d74994b" ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537" + +[compat] +ProximalOperators = "0.15"