Introducing ProximalCore #67

Merged: 6 commits, Feb 19, 2022
4 changes: 2 additions & 2 deletions Project.toml
@@ -5,10 +5,10 @@ version = "0.5.0"
[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537"
ProximalCore = "dc4f5ac2-75d1-4f31-931e-60435d74994b"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ProximalOperators = "0.14"
ProximalCore = "0.1"
Zygote = "0.6"
julia = "1.2"
2 changes: 1 addition & 1 deletion benchmark/Project.toml
@@ -8,4 +8,4 @@ ProximalAlgorithms = "140ffc9f-1907-541a-a177-7475e0a401e9"
ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537"

[compat]
ProximalOperators = "0.14"
ProximalOperators = "0.15"
4 changes: 4 additions & 0 deletions docs/Project.toml
@@ -5,4 +5,8 @@ HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
ProximalAlgorithms = "140ffc9f-1907-541a-a177-7475e0a401e9"
ProximalCore = "dc4f5ac2-75d1-4f31-931e-60435d74994b"
ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537"

[compat]
ProximalOperators = "0.15"
5 changes: 3 additions & 2 deletions docs/make.jl
@@ -1,4 +1,5 @@
using Documenter, DocumenterCitations, ProximalAlgorithms
using Documenter, DocumenterCitations
using ProximalAlgorithms, ProximalCore
using Literate

bib = CitationBibliography(joinpath(@__DIR__, "references.bib"))
@@ -22,7 +23,7 @@ end

makedocs(
bib,
modules=[ProximalAlgorithms],
modules=[ProximalAlgorithms, ProximalCore],
sitename="ProximalAlgorithms.jl",
pages=[
"Home" => "index.md",
34 changes: 18 additions & 16 deletions docs/src/guide/custom_objectives.jl
@@ -1,25 +1,27 @@
# # [Custom objective terms](@id custom_terms)
#
# ProximalAlgorithms relies on the first-order primitives implemented in [ProximalOperators](https://github.com/JuliaFirstOrder/ProximalOperators.jl):
# while a rich library of function types is provided there, one may need to formulate
# problems using custom objective terms.
# ProximalAlgorithms relies on the first-order primitives defined in [ProximalCore](https://github.com/JuliaFirstOrder/ProximalCore.jl).
# While a rich library of function types, implementing such primitives, is provided by [ProximalOperators](https://github.com/JuliaFirstOrder/ProximalOperators.jl),
# one may need to formulate problems using custom objective terms.
# When that is the case, one only needs to implement the right first-order primitive,
# ``\nabla f`` or ``\operatorname{prox}_{\gamma f}`` or both, for algorithms to be able
# to work with ``f``.
#
# Defining the proximal mapping for a custom function type requires adding a method for [`prox!`](@ref ProximalAlgorithms.prox!).
# Defining the proximal mapping for a custom function type requires adding a method for [`ProximalCore.prox!`](@ref).
#
# To compute gradients, ProximalAlgorithms provides a fallback definition for [`gradient!`](@ref ProximalAlgorithms.gradient!),
# To compute gradients, ProximalAlgorithms provides a fallback definition for [`ProximalCore.gradient!`](@ref),
# relying on [Zygote](https://github.com/FluxML/Zygote.jl) to use automatic differentiation.
# Therefore, you can provide any (differentiable) Julia function wherever gradients need to be taken,
# and everything will work out of the box.
#
# If however one would like to provide their own gradient implementation (e.g. for efficiency reasons),
# they can simply implement a method for [`gradient!`](@ref ProximalAlgorithms.gradient!).
# they can simply implement a method for [`ProximalCore.gradient!`](@ref).
#
# ```@docs
# ProximalAlgorithms.prox!(y, f, x, gamma)
# ProximalAlgorithms.gradient!(g, f, x)
# ProximalCore.prox
# ProximalCore.prox!
# ProximalCore.gradient
# ProximalCore.gradient!
# ```
#
# ## Example: constrained Rosenbrock
@@ -33,13 +35,13 @@ rosenbrock2D(x) = 100 * (x[2] - x[1]^2)^2 + (1 - x[1])^2
# outside of the set.

using LinearAlgebra
using ProximalOperators
using ProximalCore

struct IndUnitBall <: ProximalOperators.ProximableFunction end
struct IndUnitBall end

(::IndUnitBall)(x) = norm(x) > 1 ? eltype(x)(Inf) : eltype(x)(0)

function ProximalOperators.prox!(y, ::IndUnitBall, x, gamma)
function ProximalCore.prox!(y, ::IndUnitBall, x, gamma)
if norm(x) > 1
y .= x ./ norm(x)
else
@@ -73,7 +75,7 @@ scatter!([solution[1]], [solution[2]], color=:red, markershape=:star5, label="co
#
# We can achieve this by wrapping functions in a dedicated `Counting` type:

mutable struct Counting{T} <: ProximalOperators.ProximableFunction
mutable struct Counting{T}
f::T
gradient_count::Int
prox_count::Int
@@ -83,14 +85,14 @@ Counting(f::T) where T = Counting{T}(f, 0, 0)

# Now we only need to intercept any call to `gradient!` and `prox!` and increase counters there:

function ProximalOperators.gradient!(y, f::Counting, x)
function ProximalCore.gradient!(y, f::Counting, x)
f.gradient_count += 1
return ProximalOperators.gradient!(y, f.f, x)
return ProximalCore.gradient!(y, f.f, x)
end

function ProximalOperators.prox!(y, f::Counting, x, gamma)
function ProximalCore.prox!(y, f::Counting, x, gamma)
f.prox_count += 1
return ProximalOperators.prox!(y, f.f, x, gamma)
return ProximalCore.prox!(y, f.f, x, gamma)
end

# We can run again the previous example, this time wrapping the objective terms within `Counting`:
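Note on the documentation change above: the custom types no longer subtype `ProximalOperators.ProximableFunction`; a plain struct with a `prox!` method is enough. The sketch below reproduces that pattern in standalone form. It assumes a locally declared `prox!` generic function as a stand-in for `ProximalCore.prox!`, so that it runs without any package installed. The `else` branch and the return value of the `IndUnitBall` method are also assumptions, since that part of the file is collapsed in the diff.

```julia
using LinearAlgebra

# Local stand-in (assumption) for the generic function `ProximalCore.prox!`,
# declared here so the sketch runs without installing any package.
function prox! end

# Plain struct: no abstract supertype is required.
struct IndUnitBall end

# Indicator of the Euclidean unit ball: 0 inside the ball, Inf outside.
(::IndUnitBall)(x) = norm(x) > 1 ? eltype(x)(Inf) : eltype(x)(0)

# The proximal mapping of an indicator function is the projection onto the
# set, whatever the stepsize `gamma`. The `else` branch and the returned
# value are assumptions: they are not visible in the collapsed diff.
function prox!(y, ::IndUnitBall, x, gamma)
    if norm(x) > 1
        y .= x ./ norm(x)
    else
        y .= x
    end
    return zero(eltype(x))
end

# Wrapper counting how many times `prox!` is called on the wrapped term.
mutable struct Counting{T}
    f::T
    prox_count::Int
end

Counting(f::T) where T = Counting{T}(f, 0)

# Intercept the call, bump the counter, then forward to the wrapped term.
function prox!(y, c::Counting, x, gamma)
    c.prox_count += 1
    return prox!(y, c.f, x, gamma)
end

g = Counting(IndUnitBall())
x = [3.0, 4.0]
y = similar(x)
fy = prox!(y, g, x, 1.0)
```

Here `x` lies outside the ball, so `y` ends up as `x` rescaled to unit norm, and `g.prox_count` records the single call.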
23 changes: 2 additions & 21 deletions src/ProximalAlgorithms.jl
@@ -1,33 +1,14 @@
module ProximalAlgorithms

using ProximalOperators
import ProximalOperators: prox!, gradient!
using ProximalCore
using ProximalCore: prox, prox!, gradient, gradient!

const RealOrComplex{R} = Union{R,Complex{R}}
const Maybe{T} = Union{T,Nothing}

"""
prox!(y, f, x, gamma)

Compute the proximal mapping of `f` at `x`, with stepsize `gamma`, and store the result in `y`.
Return the value of `f` at `y`.
"""
prox!(y, f, x, gamma)

"""
gradient!(g, f, x)

Compute the gradient of `f` at `x`, and stores it in `y`. Return the value of `f` at `x`.
"""
gradient!(y, f, x)

# TODO move out
ProximalOperators.is_quadratic(::Any) = false

# various utilities

include("utilities/ad.jl")
include("utilities/conjugate.jl")
include("utilities/fb_tools.jl")
include("utilities/iteration_tools.jl")

2 changes: 1 addition & 1 deletion src/algorithms/davis_yin.jl
@@ -3,7 +3,7 @@
# pp. 829–858 (2017).

using Printf
using ProximalOperators: Zero
using ProximalCore: Zero
using LinearAlgebra
using Printf

2 changes: 1 addition & 1 deletion src/algorithms/douglas_rachford.jl
@@ -3,7 +3,7 @@
# Mathematical Programming, vol. 55, no. 1, pp. 293-318 (1989).

using Base.Iterators
using ProximalOperators: Zero
using ProximalCore: Zero
using LinearAlgebra
using Printf

35 changes: 15 additions & 20 deletions src/algorithms/drls.jl
@@ -6,24 +6,20 @@

using Base.Iterators
using ProximalAlgorithms.IterationTools
using ProximalOperators: Zero
using ProximalCore: Zero
using LinearAlgebra
using Printf

function drls_default_gamma(f, mf, Lf, alpha, lambda)
function drls_default_gamma(f::Tf, mf, Lf, alpha, lambda) where Tf
if mf !== nothing && mf > 0
return 1 / (alpha * mf)
end
if ProximalOperators.is_convex(f)
return alpha / Lf
else
return alpha * (2 - lambda) / (2 * Lf)
end
return ProximalCore.is_convex(Tf) ? alpha / Lf : alpha * (2 - lambda) / (2 * Lf)
end

function drls_C(f, mf, Lf, gamma, lambda)
function drls_C(f::Tf, mf, Lf, gamma, lambda) where Tf
a = mf === nothing || mf <= 0 ? gamma * Lf : 1 / (gamma * mf)
m = ProximalOperators.is_convex(f) ? max(a - lambda / 2, 0) : 1
m = ProximalCore.is_convex(Tf) ? max(a - lambda / 2, 0) : 1
return (lambda / ((1 + a)^2) * ((2 - lambda) / 2 - a * m))
end

@@ -121,7 +117,7 @@ update_direction_state!(::NesterovStyle, ::DRLSIteration, state::DRLSState) = re
update_direction_state!(::NoAccelerationStyle, ::DRLSIteration, state::DRLSState) = return
update_direction_state!(iter::DRLSIteration, state::DRLSState) = update_direction_state!(acceleration_style(typeof(iter.directions)), iter, state)

function Base.iterate(iter::DRLSIteration{R}, state::DRLSState) where R
function Base.iterate(iter::DRLSIteration{R, Tx, Tf}, state::DRLSState) where {R, Tx, Tf}
DRE_curr = DRE(state)

set_next_direction!(iter, state)
@@ -150,16 +146,15 @@ function Base.iterate(iter::DRLSIteration{R}, state::DRLSState) where R
state.tau = k == iter.max_backtracks ? R(0) : state.tau / 2
state.x .= state.tau .* state.x_d .+ (1 - state.tau) .* state.xbar_prev

if k == 1 && ProximalOperators.is_generalized_quadratic(iter.f)
copyto!(state.u1, state.u)
c = prox!(state.u0, iter.f, state.xbar_prev, iter.gamma)
state.temp_x1 .= state.xbar_prev .- state.x_d
state.temp_x2 .= state.xbar_prev .- state.u0
b = real(dot(state.temp_x1, state.temp_x2)) / iter.gamma
a = state.f_u - b - c
end

if ProximalOperators.is_generalized_quadratic(iter.f)
if ProximalCore.is_generalized_quadratic(Tf)
if k == 1
copyto!(state.u1, state.u)
c = prox!(state.u0, iter.f, state.xbar_prev, iter.gamma)
state.temp_x1 .= state.xbar_prev .- state.x_d
state.temp_x2 .= state.xbar_prev .- state.u0
b = real(dot(state.temp_x1, state.temp_x2)) / iter.gamma
a = state.f_u - b - c
end
state.u .= state.tau .* state.u1 .+ (1 - state.tau) .* state.u0
state.f_u = a * state.tau ^ 2 + b * state.tau + c
else
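Note on the change above: the property queries move from the value (`is_convex(f)`) to the type (`is_convex(Tf)`), with `Tf` bound in the method signature. A minimal sketch of that pattern follows. The `is_convex` trait is defined locally as a stand-in for `ProximalCore.is_convex`, and the types `ConvexTerm` and `NonconvexTerm` are hypothetical, introduced only for illustration.

```julia
# Local stand-in (assumption) for a type-level trait such as
# `ProximalCore.is_convex`: the answer depends only on the type.
is_convex(::Type) = false

# Hypothetical objective terms, used only for illustration.
struct ConvexTerm end
struct NonconvexTerm end

# Opt in for the type that has the property.
is_convex(::Type{ConvexTerm}) = true

# Same structure as `drls_default_gamma` in the diff: `Tf` is bound by the
# method signature, so the trait is evaluated on the type, not the value.
function default_gamma(f::Tf, mf, Lf, alpha, lambda) where Tf
    if mf !== nothing && mf > 0
        return 1 / (alpha * mf)
    end
    return is_convex(Tf) ? alpha / Lf : alpha * (2 - lambda) / (2 * Lf)
end

gamma_convex = default_gamma(ConvexTerm(), nothing, 2.0, 0.5, 1.0)
gamma_nonconvex = default_gamma(NonconvexTerm(), nothing, 2.0, 0.5, 1.0)
gamma_strong = default_gamma(ConvexTerm(), 4.0, 2.0, 0.5, 1.0)
```

Types that do not opt in fall back to `false`, so the nonconvex stepsize is the conservative default in this sketch.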
2 changes: 1 addition & 1 deletion src/algorithms/fast_forward_backward.jl
@@ -7,7 +7,7 @@

using Base.Iterators
using ProximalAlgorithms.IterationTools
using ProximalOperators: Zero
using ProximalCore: Zero
using LinearAlgebra
using Printf

2 changes: 1 addition & 1 deletion src/algorithms/forward_backward.jl
@@ -3,7 +3,7 @@

using Base.Iterators
using ProximalAlgorithms.IterationTools
using ProximalOperators: Zero
using ProximalCore: Zero
using LinearAlgebra
using Printf

2 changes: 1 addition & 1 deletion src/algorithms/li_lin.jl
@@ -3,7 +3,7 @@

using Base.Iterators
using ProximalAlgorithms.IterationTools
using ProximalOperators: Zero
using ProximalCore: Zero
using LinearAlgebra
using Printf

6 changes: 3 additions & 3 deletions src/algorithms/panoc.jl
@@ -4,7 +4,7 @@

using Base.Iterators
using ProximalAlgorithms.IterationTools
using ProximalOperators: Zero
using ProximalCore: Zero
using LinearAlgebra
using Printf

@@ -110,7 +110,7 @@ reset_direction_state!(::QuasiNewtonStyle, ::PANOCIteration, state::PANOCState)
reset_direction_state!(::NoAccelerationStyle, ::PANOCIteration, state::PANOCState) = return
reset_direction_state!(iter::PANOCIteration, state::PANOCState) = reset_direction_state!(acceleration_style(typeof(iter.directions)), iter, state)

function Base.iterate(iter::PANOCIteration{R}, state::PANOCState) where R
function Base.iterate(iter::PANOCIteration{R, Tx, Tf}, state::PANOCState) where {R, Tx, Tf}
f_Az, a, b, c = R(Inf), R(Inf), R(Inf), R(Inf)

f_Az_upp = if iter.adaptive == true
@@ -177,7 +177,7 @@ function Base.iterate(iter::PANOCIteration{R}, state::PANOCState) where R
state.x .= state.tau .* state.x_d .+ (1 - state.tau) .* state.z_curr
state.Ax .= state.tau .* state.Ax_d .+ (1 - state.tau) .* state.Az

if ProximalOperators.is_quadratic(iter.f)
if ProximalCore.is_generalized_quadratic(Tf)
# in case f is quadratic, we can compute its value and gradient
# along a line using interpolation and linear combinations
# this allows saving operations
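Note on the comment above about quadratic terms: when `f` is quadratic, its value along a segment is a degree-two polynomial in the stepsize, so it can be interpolated instead of re-evaluated. The sketch below illustrates the idea on a small example. The names `Q`, `q`, `z`, `d` and `phi` are hypothetical and do not come from the package; the coefficient formula `a = f(z + d) - b - c` mirrors the one visible in the `drls.jl` hunk.

```julia
using LinearAlgebra

# Hypothetical quadratic objective f(x) = 0.5 * x'Qx + q'x.
Q = [2.0 0.5; 0.5 1.0]
q = [1.0, -1.0]
f(x) = 0.5 * dot(x, Q * x) + dot(q, x)
grad_f(x) = Q * x + q

z = [1.0, 2.0]      # base point
d = [-0.5, 0.25]    # search direction

# f(z + tau * d) = a * tau^2 + b * tau + c for every tau, with:
c = f(z)                  # value at tau = 0
b = dot(grad_f(z), d)     # slope at tau = 0
a = f(z + d) - b - c      # fixed by the value at tau = 1

# Interpolated value along the line: no further evaluation of f is needed.
phi(tau) = a * tau^2 + b * tau + c

tau = 0.3
interpolated = phi(tau)
direct = f(z .+ tau .* d)
```

During backtracking this means each halving of `tau` costs a few scalar operations rather than a new function evaluation.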
2 changes: 1 addition & 1 deletion src/algorithms/panocplus.jl
@@ -4,7 +4,7 @@

using Base.Iterators
using ProximalAlgorithms.IterationTools
using ProximalOperators: Zero
using ProximalCore: Zero
using LinearAlgebra
using Printf

6 changes: 3 additions & 3 deletions src/algorithms/primal_dual.jl
@@ -22,7 +22,7 @@

using Base.Iterators
using ProximalAlgorithms.IterationTools
using ProximalOperators: Zero
using ProximalCore: Zero, IndZero, convex_conjugate
using LinearAlgebra
using Printf

@@ -175,13 +175,13 @@ function Base.iterate(iter::AFBAIteration, state::AFBAState = AFBAState(x=copy(i
prox!(state.xbar, iter.g, state.temp_x, iter.gamma[1])

# perform ybar-update step
gradient!(state.gradl, Conjugate(iter.l), state.y)
gradient!(state.gradl, convex_conjugate(iter.l), state.y)
state.temp_x .= iter.theta .* state.xbar .+ (1 - iter.theta) .* state.x
mul!(state.temp_y, iter.L, state.temp_x)
state.temp_y .-= state.gradl
state.temp_y .*= iter.gamma[2]
state.temp_y .+= state.y
prox!(state.ybar, Conjugate(iter.h), state.temp_y, iter.gamma[2])
prox!(state.ybar, convex_conjugate(iter.h), state.temp_y, iter.gamma[2])

# the residues
state.FPR_x .= state.xbar .- state.x
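Note on `convex_conjugate` above: the diff does not show how the proximal mapping of a conjugate is computed. One standard way to obtain it from the proximal mapping of the original function is Moreau's identity, prox_{γ f*}(x) = x - γ prox_{f/γ}(x/γ). The sketch below checks that identity on the l1 norm, whose conjugate is the indicator of the l-infinity unit ball. All function names here are hypothetical, and this is not claimed to be ProximalCore's implementation.

```julia
# Proximal mapping of gamma * ||.||_1: elementwise soft thresholding.
prox_l1(x, gamma) = sign.(x) .* max.(abs.(x) .- gamma, 0)

# Moreau's identity: prox of gamma * f^* expressed through the prox of f.
# `prox_f(x, t)` must return the proximal mapping of t * f at x.
function prox_conjugate(prox_f, x, gamma)
    return x .- gamma .* prox_f(x ./ gamma, 1 / gamma)
end

x = [2.0, -0.3, 0.7, -5.0]
gamma = 0.5

# Via the identity.
y = prox_conjugate(prox_l1, x, gamma)

# Direct computation: the conjugate of the l1 norm is the indicator of the
# l-infinity unit ball, whose proximal mapping is the projection onto it.
y_direct = clamp.(x, -1, 1)
```

Both routes give the same point, so a wrapper around `f` only needs `f`'s own `prox!` to serve as a proximable conjugate.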
2 changes: 1 addition & 1 deletion src/algorithms/sfista.jl
@@ -2,7 +2,7 @@

using Base.Iterators
using ProximalAlgorithms.IterationTools
using ProximalOperators: Zero
using ProximalCore: Zero
using LinearAlgebra
using Printf

2 changes: 1 addition & 1 deletion src/algorithms/zerofpr.jl
@@ -5,7 +5,7 @@

using Base.Iterators
using ProximalAlgorithms.IterationTools
using ProximalOperators: Zero
using ProximalCore: Zero
using LinearAlgebra
using Printf

12 changes: 3 additions & 9 deletions src/utilities/ad.jl
@@ -1,14 +1,8 @@
using Zygote: pullback
using ProximalOperators
using ProximalCore

function ProximalOperators.gradient(f, x)
function ProximalCore.gradient!(grad, f, x)
fx, pb = pullback(f, x)
grad = pb(one(fx))[1]
return grad, fx
end

function ProximalOperators.gradient!(grad, f, x)
y, fx = gradient(f, x)
grad .= y
grad .= pb(one(fx))[1]
return fx
end
8 changes: 0 additions & 8 deletions src/utilities/conjugate.jl

This file was deleted.

1 change: 1 addition & 0 deletions test/Project.toml
@@ -2,6 +2,7 @@
AbstractOperators = "d9c5613a-d543-52d8-9afd-8f241a8c3f1c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProximalCore = "dc4f5ac2-75d1-4f31-931e-60435d74994b"
ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"