Skip to content

Commit

Permalink
Add autodiff fallback for gradient computation using Zygote (#52)
Browse files Browse the repository at this point in the history
  • Loading branch information
lostella authored Oct 20, 2021
1 parent ab341f2 commit 96fd0c2
Show file tree
Hide file tree
Showing 17 changed files with 80 additions and 29 deletions.
2 changes: 0 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ jobs:
fail-fast: false
matrix:
version:
- '1.1'
- '1.5'
- '1.6'
os:
Expand All @@ -34,4 +33,3 @@ jobs:
- uses: julia-actions/julia-uploadcodecov@latest
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}

13 changes: 3 additions & 10 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,9 @@ version = "0.5.0"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ProximalOperators = "0.14"
julia = "1.1.0"

[extras]
AbstractOperators = "d9c5613a-d543-52d8-9afd-8f241a8c3f1c"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Random", "Test", "RecursiveArrayTools", "AbstractOperators"]
Zygote = "0.6"
julia = "1.2"
5 changes: 3 additions & 2 deletions src/ProximalAlgorithms.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
module ProximalAlgorithms

using ProximalOperators

const RealOrComplex{R} = Union{R,Complex{R}}
const Maybe{T} = Union{T,Nothing}

include("compat.jl")

# utilities

include("utilities/ad.jl")
include("utilities/conjugate.jl")
include("utilities/fb_tools.jl")
include("utilities/iteration_tools.jl")
Expand Down
5 changes: 0 additions & 5 deletions src/compat.jl

This file was deleted.

14 changes: 14 additions & 0 deletions src/utilities/ad.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
using Zygote: pullback
using ProximalOperators

function ProximalOperators.gradient(f, x)
fx, pb = pullback(f, x)
grad = pb(one(fx))[1]
return grad, fx
end

function ProximalOperators.gradient!(grad, f, x)
y, fx = gradient(f, x)
grad .= y
return fx
end
4 changes: 2 additions & 2 deletions src/utilities/conjugate.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
using ProximalOperators
using ProximalOperators: Zero

ProximalOperators.Conjugate(f::Zero) = IndZero()
ProximalOperators.Conjugate(f::IndZero) = Zero()
ProximalOperators.Conjugate(_::Zero) = IndZero()
ProximalOperators.Conjugate(_::IndZero) = Zero()
ProximalOperators.Conjugate(f::SqrNormL2) = SqrNormL2(1.0 / f.lambda)

# TODO: Add other useful functions and calculus rules such as translation
8 changes: 8 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
[deps]
AbstractOperators = "d9c5613a-d543-52d8-9afd-8f241a8c3f1c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
17 changes: 9 additions & 8 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@ using Test
include("definitions/arraypartition.jl")
include("definitions/compose.jl")

include("utilities/iteration_tools.jl")
include("utilities/conjugate.jl")
include("utilities/fb_tools.jl")
include("utilities/test_ad.jl")
include("utilities/test_iteration_tools.jl")
include("utilities/test_conjugate.jl")
include("utilities/test_fb_tools.jl")

include("accel/lbfgs.jl")
include("accel/anderson.jl")
include("accel/nesterov.jl")
include("accel/broyden.jl")
include("accel/noaccel.jl")
include("accel/test_lbfgs.jl")
include("accel/test_anderson.jl")
include("accel/test_nesterov.jl")
include("accel/test_broyden.jl")
include("accel/test_noaccel.jl")

include("problems/test_equivalence.jl")
include("problems/test_elasticnet.jl")
Expand Down
39 changes: 39 additions & 0 deletions test/utilities/test_ad.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
using Test
using LinearAlgebra
using ProximalOperators: NormL1
using ProximalAlgorithms

@testset "Autodiff ($T)" for T in [Float32, Float64, ComplexF32, ComplexF64]
R = real(T)
A = T[
1.0 -2.0 3.0 -4.0 5.0
2.0 -1.0 0.0 -1.0 3.0
-1.0 0.0 4.0 -3.0 2.0
-1.0 -1.0 -1.0 1.0 3.0
]
b = T[1.0, 2.0, 3.0, 4.0]
f(x) = R(1/2) * norm(A * x - b, 2)^2
Lf = opnorm(A)^2
m, n = size(A)

@testset "Gradient" begin
x = randn(T, n)
gradfx, fx = gradient(f, x)
@test eltype(gradfx) == T
@test typeof(fx) == R
@test gradfx A' * (A * x - b)
end

@testset "Algorithms" begin
lam = R(0.1) * norm(A' * b, Inf)
@test typeof(lam) == R
g = NormL1(lam)
x_star = T[-3.877278911564627e-01, 0, 0, 2.174149659863943e-02, 6.168435374149660e-01]
TOL = R(1e-4)
solver = ProximalAlgorithms.FastForwardBackward(tol = TOL)
x, it = solver(zeros(T, n), f = f, g = g, Lf = Lf)
@test eltype(x) == T
@test norm(x - x_star, Inf) <= TOL
@test it < 100
end
end
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using Test

@testset "Conjugate" begin

using ProximalOperators
Expand Down
File renamed without changes.
File renamed without changes.

0 comments on commit 96fd0c2

Please sign in to comment.