Skip to content

Commit

Permalink
Introducing ProximalCore (#67)
Browse files Browse the repository at this point in the history
  • Loading branch information
lostella authored Feb 19, 2022
1 parent 69e6fc2 commit 9b5c3ae
Show file tree
Hide file tree
Showing 25 changed files with 77 additions and 155 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ version = "0.5.0"
[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537"
ProximalCore = "dc4f5ac2-75d1-4f31-931e-60435d74994b"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ProximalOperators = "0.14"
ProximalCore = "0.1"
Zygote = "0.6"
julia = "1.2"
2 changes: 1 addition & 1 deletion benchmark/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ ProximalAlgorithms = "140ffc9f-1907-541a-a177-7475e0a401e9"
ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537"

[compat]
ProximalOperators = "0.14"
ProximalOperators = "0.15"
4 changes: 4 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,8 @@ HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
ProximalAlgorithms = "140ffc9f-1907-541a-a177-7475e0a401e9"
ProximalCore = "dc4f5ac2-75d1-4f31-931e-60435d74994b"
ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537"

[compat]
ProximalOperators = "0.15"
5 changes: 3 additions & 2 deletions docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Documenter, DocumenterCitations, ProximalAlgorithms
using Documenter, DocumenterCitations
using ProximalAlgorithms, ProximalCore
using Literate

bib = CitationBibliography(joinpath(@__DIR__, "references.bib"))
Expand All @@ -22,7 +23,7 @@ end

makedocs(
bib,
modules=[ProximalAlgorithms],
modules=[ProximalAlgorithms, ProximalCore],
sitename="ProximalAlgorithms.jl",
pages=[
"Home" => "index.md",
Expand Down
34 changes: 18 additions & 16 deletions docs/src/guide/custom_objectives.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,27 @@
# # [Custom objective terms](@id custom_terms)
#
# ProximalAlgorithms relies on the first-order primitives implemented in [ProximalOperators](https://github.com/JuliaFirstOrder/ProximalOperators.jl):
# while a rich library of function types is provided there, one may need to formulate
# problems using custom objective terms.
# ProximalAlgorithms relies on the first-order primitives defined in [ProximalCore](https://github.com/JuliaFirstOrder/ProximalCore.jl).
# While a rich library of function types, implementing such primitives, is provided by [ProximalOperators](https://github.com/JuliaFirstOrder/ProximalOperators.jl),
# one may need to formulate problems using custom objective terms.
# When that is the case, one only needs to implement the right first-order primitive,
# ``\nabla f`` or ``\operatorname{prox}_{\gamma f}`` or both, for algorithms to be able
# to work with ``f``.
#
# Defining the proximal mapping for a custom function type requires adding a method for [`prox!`](@ref ProximalAlgorithms.prox!).
# Defining the proximal mapping for a custom function type requires adding a method for [`ProximalCore.prox!`](@ref).
#
# To compute gradients, ProximalAlgorithms provides a fallback definition for [`gradient!`](@ref ProximalAlgorithms.gradient!),
# To compute gradients, ProximalAlgorithms provides a fallback definition for [`ProximalCore.gradient!`](@ref),
# relying on [Zygote](https://github.com/FluxML/Zygote.jl) to use automatic differentiation.
# Therefore, you can provide any (differentiable) Julia function wherever gradients need to be taken,
# and everything will work out of the box.
#
# If however one would like to provide their own gradient implementation (e.g. for efficiency reasons),
# they can simply implement a method for [`gradient!`](@ref ProximalAlgorithms.gradient!).
# they can simply implement a method for [`ProximalCore.gradient!`](@ref).
#
# ```@docs
# ProximalAlgorithms.prox!(y, f, x, gamma)
# ProximalAlgorithms.gradient!(g, f, x)
# ProximalCore.prox
# ProximalCore.prox!
# ProximalCore.gradient
# ProximalCore.gradient!
# ```
#
# ## Example: constrained Rosenbrock
Expand All @@ -33,13 +35,13 @@ rosenbrock2D(x) = 100 * (x[2] - x[1]^2)^2 + (1 - x[1])^2
# outside of the set.

using LinearAlgebra
using ProximalOperators
using ProximalCore

struct IndUnitBall <: ProximalOperators.ProximableFunction end
struct IndUnitBall end

(::IndUnitBall)(x) = norm(x) > 1 ? eltype(x)(Inf) : eltype(x)(0)

function ProximalOperators.prox!(y, ::IndUnitBall, x, gamma)
function ProximalCore.prox!(y, ::IndUnitBall, x, gamma)
if norm(x) > 1
y .= x ./ norm(x)
else
Expand Down Expand Up @@ -73,7 +75,7 @@ scatter!([solution[1]], [solution[2]], color=:red, markershape=:star5, label="co
#
# We can achieve this by wrapping functions in a dedicated `Counting` type:

mutable struct Counting{T} <: ProximalOperators.ProximableFunction
mutable struct Counting{T}
f::T
gradient_count::Int
prox_count::Int
Expand All @@ -83,14 +85,14 @@ Counting(f::T) where T = Counting{T}(f, 0, 0)

# Now we only need to intercept any call to `gradient!` and `prox!` and increase counters there:

function ProximalOperators.gradient!(y, f::Counting, x)
function ProximalCore.gradient!(y, f::Counting, x)
f.gradient_count += 1
return ProximalOperators.gradient!(y, f.f, x)
return ProximalCore.gradient!(y, f.f, x)
end

function ProximalOperators.prox!(y, f::Counting, x, gamma)
function ProximalCore.prox!(y, f::Counting, x, gamma)
f.prox_count += 1
return ProximalOperators.prox!(y, f.f, x, gamma)
return ProximalCore.prox!(y, f.f, x, gamma)
end

# We can run again the previous example, this time wrapping the objective terms within `Counting`:
Expand Down
23 changes: 2 additions & 21 deletions src/ProximalAlgorithms.jl
Original file line number Diff line number Diff line change
@@ -1,33 +1,14 @@
module ProximalAlgorithms

using ProximalOperators
import ProximalOperators: prox!, gradient!
using ProximalCore
using ProximalCore: prox, prox!, gradient, gradient!

const RealOrComplex{R} = Union{R,Complex{R}}
const Maybe{T} = Union{T,Nothing}

"""
prox!(y, f, x, gamma)
Compute the proximal mapping of `f` at `x`, with stepsize `gamma`, and store the result in `y`.
Return the value of `f` at `y`.
"""
prox!(y, f, x, gamma)

"""
gradient!(g, f, x)
Compute the gradient of `f` at `x`, and stores it in `y`. Return the value of `f` at `x`.
"""
gradient!(y, f, x)

# TODO move out
ProximalOperators.is_quadratic(::Any) = false

# various utilities

include("utilities/ad.jl")
include("utilities/conjugate.jl")
include("utilities/fb_tools.jl")
include("utilities/iteration_tools.jl")

Expand Down
2 changes: 1 addition & 1 deletion src/algorithms/davis_yin.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# pp. 829–858 (2017).

using Printf
using ProximalOperators: Zero
using ProximalCore: Zero
using LinearAlgebra
using Printf

Expand Down
2 changes: 1 addition & 1 deletion src/algorithms/douglas_rachford.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Mathematical Programming, vol. 55, no. 1, pp. 293-318 (1989).

using Base.Iterators
using ProximalOperators: Zero
using ProximalCore: Zero
using LinearAlgebra
using Printf

Expand Down
35 changes: 15 additions & 20 deletions src/algorithms/drls.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,20 @@

using Base.Iterators
using ProximalAlgorithms.IterationTools
using ProximalOperators: Zero
using ProximalCore: Zero
using LinearAlgebra
using Printf

function drls_default_gamma(f, mf, Lf, alpha, lambda)
function drls_default_gamma(f::Tf, mf, Lf, alpha, lambda) where Tf
if mf !== nothing && mf > 0
return 1 / (alpha * mf)
end
if ProximalOperators.is_convex(f)
return alpha / Lf
else
return alpha * (2 - lambda) / (2 * Lf)
end
return ProximalCore.is_convex(Tf) ? alpha / Lf : alpha * (2 - lambda) / (2 * Lf)
end

function drls_C(f, mf, Lf, gamma, lambda)
function drls_C(f::Tf, mf, Lf, gamma, lambda) where Tf
a = mf === nothing || mf <= 0 ? gamma * Lf : 1 / (gamma * mf)
m = ProximalOperators.is_convex(f) ? max(a - lambda / 2, 0) : 1
m = ProximalCore.is_convex(Tf) ? max(a - lambda / 2, 0) : 1
return (lambda / ((1 + a)^2) * ((2 - lambda) / 2 - a * m))
end

Expand Down Expand Up @@ -121,7 +117,7 @@ update_direction_state!(::NesterovStyle, ::DRLSIteration, state::DRLSState) = re
update_direction_state!(::NoAccelerationStyle, ::DRLSIteration, state::DRLSState) = return
update_direction_state!(iter::DRLSIteration, state::DRLSState) = update_direction_state!(acceleration_style(typeof(iter.directions)), iter, state)

function Base.iterate(iter::DRLSIteration{R}, state::DRLSState) where R
function Base.iterate(iter::DRLSIteration{R, Tx, Tf}, state::DRLSState) where {R, Tx, Tf}
DRE_curr = DRE(state)

set_next_direction!(iter, state)
Expand Down Expand Up @@ -150,16 +146,15 @@ function Base.iterate(iter::DRLSIteration{R}, state::DRLSState) where R
state.tau = k == iter.max_backtracks ? R(0) : state.tau / 2
state.x .= state.tau .* state.x_d .+ (1 - state.tau) .* state.xbar_prev

if k == 1 && ProximalOperators.is_generalized_quadratic(iter.f)
copyto!(state.u1, state.u)
c = prox!(state.u0, iter.f, state.xbar_prev, iter.gamma)
state.temp_x1 .= state.xbar_prev .- state.x_d
state.temp_x2 .= state.xbar_prev .- state.u0
b = real(dot(state.temp_x1, state.temp_x2)) / iter.gamma
a = state.f_u - b - c
end

if ProximalOperators.is_generalized_quadratic(iter.f)
if ProximalCore.is_generalized_quadratic(Tf)
if k == 1
copyto!(state.u1, state.u)
c = prox!(state.u0, iter.f, state.xbar_prev, iter.gamma)
state.temp_x1 .= state.xbar_prev .- state.x_d
state.temp_x2 .= state.xbar_prev .- state.u0
b = real(dot(state.temp_x1, state.temp_x2)) / iter.gamma
a = state.f_u - b - c
end
state.u .= state.tau .* state.u1 .+ (1 - state.tau) .* state.u0
state.f_u = a * state.tau ^ 2 + b * state.tau + c
else
Expand Down
2 changes: 1 addition & 1 deletion src/algorithms/fast_forward_backward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

using Base.Iterators
using ProximalAlgorithms.IterationTools
using ProximalOperators: Zero
using ProximalCore: Zero
using LinearAlgebra
using Printf

Expand Down
2 changes: 1 addition & 1 deletion src/algorithms/forward_backward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

using Base.Iterators
using ProximalAlgorithms.IterationTools
using ProximalOperators: Zero
using ProximalCore: Zero
using LinearAlgebra
using Printf

Expand Down
2 changes: 1 addition & 1 deletion src/algorithms/li_lin.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

using Base.Iterators
using ProximalAlgorithms.IterationTools
using ProximalOperators: Zero
using ProximalCore: Zero
using LinearAlgebra
using Printf

Expand Down
6 changes: 3 additions & 3 deletions src/algorithms/panoc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

using Base.Iterators
using ProximalAlgorithms.IterationTools
using ProximalOperators: Zero
using ProximalCore: Zero
using LinearAlgebra
using Printf

Expand Down Expand Up @@ -110,7 +110,7 @@ reset_direction_state!(::QuasiNewtonStyle, ::PANOCIteration, state::PANOCState)
reset_direction_state!(::NoAccelerationStyle, ::PANOCIteration, state::PANOCState) = return
reset_direction_state!(iter::PANOCIteration, state::PANOCState) = reset_direction_state!(acceleration_style(typeof(iter.directions)), iter, state)

function Base.iterate(iter::PANOCIteration{R}, state::PANOCState) where R
function Base.iterate(iter::PANOCIteration{R, Tx, Tf}, state::PANOCState) where {R, Tx, Tf}
f_Az, a, b, c = R(Inf), R(Inf), R(Inf), R(Inf)

f_Az_upp = if iter.adaptive == true
Expand Down Expand Up @@ -177,7 +177,7 @@ function Base.iterate(iter::PANOCIteration{R}, state::PANOCState) where R
state.x .= state.tau .* state.x_d .+ (1 - state.tau) .* state.z_curr
state.Ax .= state.tau .* state.Ax_d .+ (1 - state.tau) .* state.Az

if ProximalOperators.is_quadratic(iter.f)
if ProximalCore.is_generalized_quadratic(Tf)
# in case f is quadratic, we can compute its value and gradient
# along a line using interpolation and linear combinations
# this allows saving operations
Expand Down
2 changes: 1 addition & 1 deletion src/algorithms/panocplus.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

using Base.Iterators
using ProximalAlgorithms.IterationTools
using ProximalOperators: Zero
using ProximalCore: Zero
using LinearAlgebra
using Printf

Expand Down
6 changes: 3 additions & 3 deletions src/algorithms/primal_dual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

using Base.Iterators
using ProximalAlgorithms.IterationTools
using ProximalOperators: Zero
using ProximalCore: Zero, IndZero, convex_conjugate
using LinearAlgebra
using Printf

Expand Down Expand Up @@ -175,13 +175,13 @@ function Base.iterate(iter::AFBAIteration, state::AFBAState = AFBAState(x=copy(i
prox!(state.xbar, iter.g, state.temp_x, iter.gamma[1])

# perform ybar-update step
gradient!(state.gradl, Conjugate(iter.l), state.y)
gradient!(state.gradl, convex_conjugate(iter.l), state.y)
state.temp_x .= iter.theta .* state.xbar .+ (1 - iter.theta) .* state.x
mul!(state.temp_y, iter.L, state.temp_x)
state.temp_y .-= state.gradl
state.temp_y .*= iter.gamma[2]
state.temp_y .+= state.y
prox!(state.ybar, Conjugate(iter.h), state.temp_y, iter.gamma[2])
prox!(state.ybar, convex_conjugate(iter.h), state.temp_y, iter.gamma[2])

# the residues
state.FPR_x .= state.xbar .- state.x
Expand Down
2 changes: 1 addition & 1 deletion src/algorithms/sfista.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

using Base.Iterators
using ProximalAlgorithms.IterationTools
using ProximalOperators: Zero
using ProximalCore: Zero
using LinearAlgebra
using Printf

Expand Down
2 changes: 1 addition & 1 deletion src/algorithms/zerofpr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

using Base.Iterators
using ProximalAlgorithms.IterationTools
using ProximalOperators: Zero
using ProximalCore: Zero
using LinearAlgebra
using Printf

Expand Down
12 changes: 3 additions & 9 deletions src/utilities/ad.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,8 @@
using Zygote: pullback
using ProximalOperators
using ProximalCore

function ProximalOperators.gradient(f, x)
function ProximalCore.gradient!(grad, f, x)
fx, pb = pullback(f, x)
grad = pb(one(fx))[1]
return grad, fx
end

function ProximalOperators.gradient!(grad, f, x)
y, fx = gradient(f, x)
grad .= y
grad .= pb(one(fx))[1]
return fx
end
8 changes: 0 additions & 8 deletions src/utilities/conjugate.jl

This file was deleted.

1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
AbstractOperators = "d9c5613a-d543-52d8-9afd-8f241a8c3f1c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProximalCore = "dc4f5ac2-75d1-4f31-931e-60435d74994b"
ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Expand Down
Loading

2 comments on commit 9b5c3ae

@lostella
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/54960

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.5.0 -m "<description of version>" 9b5c3ae2d8b846f5256556ea3daeb06e9def9285
git push origin v0.5.0

Please sign in to comment.