From 96fd0c2a981cb278dd0855cc07792c7190acab02 Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Wed, 20 Oct 2021 08:43:12 +0200 Subject: [PATCH] Add autodiff fallback for gradient computation using Zygote (#52) --- .github/workflows/ci.yml | 2 - Project.toml | 13 ++----- src/ProximalAlgorithms.jl | 5 ++- src/compat.jl | 5 --- src/utilities/ad.jl | 14 +++++++ src/utilities/conjugate.jl | 4 +- test/Project.toml | 8 ++++ test/accel/{anderson.jl => test_anderson.jl} | 0 test/accel/{broyden.jl => test_broyden.jl} | 0 test/accel/{lbfgs.jl => test_lbfgs.jl} | 0 test/accel/{nesterov.jl => test_nesterov.jl} | 0 test/accel/{noaccel.jl => test_noaccel.jl} | 0 test/runtests.jl | 17 ++++---- test/utilities/test_ad.jl | 39 +++++++++++++++++++ .../{conjugate.jl => test_conjugate.jl} | 2 + .../{fb_tools.jl => test_fb_tools.jl} | 0 ...ation_tools.jl => test_iteration_tools.jl} | 0 17 files changed, 80 insertions(+), 29 deletions(-) delete mode 100644 src/compat.jl create mode 100644 src/utilities/ad.jl create mode 100644 test/Project.toml rename test/accel/{anderson.jl => test_anderson.jl} (100%) rename test/accel/{broyden.jl => test_broyden.jl} (100%) rename test/accel/{lbfgs.jl => test_lbfgs.jl} (100%) rename test/accel/{nesterov.jl => test_nesterov.jl} (100%) rename test/accel/{noaccel.jl => test_noaccel.jl} (100%) create mode 100644 test/utilities/test_ad.jl rename test/utilities/{conjugate.jl => test_conjugate.jl} (98%) rename test/utilities/{fb_tools.jl => test_fb_tools.jl} (100%) rename test/utilities/{iteration_tools.jl => test_iteration_tools.jl} (100%) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 91f3297..990ed1b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -14,7 +14,6 @@ jobs: fail-fast: false matrix: version: - - '1.1' - '1.5' - '1.6' os: @@ -34,4 +33,3 @@ jobs: - uses: julia-actions/julia-uploadcodecov@latest env: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} - diff --git a/Project.toml b/Project.toml index a288dea..46bcc01 100644 --- a/Project.toml +++ b/Project.toml @@ -6,16 +6,9 @@ version = "0.5.0" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ProximalOperators = "0.14" -julia = "1.1.0" - -[extras] -AbstractOperators = "d9c5613a-d543-52d8-9afd-8f241a8c3f1c" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[targets] -test = ["Random", "Test", "RecursiveArrayTools", "AbstractOperators"] +Zygote = "0.6" +julia = "1.2" diff --git a/src/ProximalAlgorithms.jl b/src/ProximalAlgorithms.jl index bc20634..ea42d89 100644 --- a/src/ProximalAlgorithms.jl +++ b/src/ProximalAlgorithms.jl @@ -1,12 +1,13 @@ module ProximalAlgorithms +using ProximalOperators + const RealOrComplex{R} = Union{R,Complex{R}} const Maybe{T} = Union{T,Nothing} -include("compat.jl") - # utilities +include("utilities/ad.jl") include("utilities/conjugate.jl") include("utilities/fb_tools.jl") include("utilities/iteration_tools.jl") diff --git a/src/compat.jl b/src/compat.jl deleted file mode 100644 index 5bec281..0000000 --- a/src/compat.jl +++ /dev/null @@ -1,5 +0,0 @@ -if VERSION < v"1.1" - using LinearAlgebra - LinearAlgebra.mul!(C::AbstractVecOrMat, J::UniformScaling, B::AbstractVecOrMat) = - mul!(C, J.λ, B) -end diff --git a/src/utilities/ad.jl b/src/utilities/ad.jl new file mode 100644 index 0000000..a25d1a8 --- /dev/null +++ b/src/utilities/ad.jl @@ -0,0 +1,14 @@ +using Zygote: pullback +using ProximalOperators + +function ProximalOperators.gradient(f, x) + fx, pb = pullback(f, x) + grad = pb(one(fx))[1] + return grad, fx +end + +function ProximalOperators.gradient!(grad, f, x) + y, fx = gradient(f, x) + grad .= y + return fx +end diff --git a/src/utilities/conjugate.jl b/src/utilities/conjugate.jl index 70a9081..d831d8a 100644 --- a/src/utilities/conjugate.jl +++ b/src/utilities/conjugate.jl @@ -1,8 +1,8 @@ using ProximalOperators using ProximalOperators: Zero -ProximalOperators.Conjugate(f::Zero) = IndZero() -ProximalOperators.Conjugate(f::IndZero) = Zero() +ProximalOperators.Conjugate(_::Zero) = IndZero() +ProximalOperators.Conjugate(_::IndZero) = Zero() ProximalOperators.Conjugate(f::SqrNormL2) = SqrNormL2(1.0 / f.lambda) # TODO: Add other useful functions and calculus rules such as translation diff --git a/test/Project.toml b/test/Project.toml new file mode 100644 index 0000000..f4cf9bb --- /dev/null +++ b/test/Project.toml @@ -0,0 +1,8 @@ +[deps] +AbstractOperators = "d9c5613a-d543-52d8-9afd-8f241a8c3f1c" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" +ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/accel/anderson.jl b/test/accel/test_anderson.jl similarity index 100% rename from test/accel/anderson.jl rename to test/accel/test_anderson.jl diff --git a/test/accel/broyden.jl b/test/accel/test_broyden.jl similarity index 100% rename from test/accel/broyden.jl rename to test/accel/test_broyden.jl diff --git a/test/accel/lbfgs.jl b/test/accel/test_lbfgs.jl similarity index 100% rename from test/accel/lbfgs.jl rename to test/accel/test_lbfgs.jl diff --git a/test/accel/nesterov.jl b/test/accel/test_nesterov.jl similarity index 100% rename from test/accel/nesterov.jl rename to test/accel/test_nesterov.jl diff --git a/test/accel/noaccel.jl b/test/accel/test_noaccel.jl similarity index 100% rename from test/accel/noaccel.jl rename to test/accel/test_noaccel.jl diff --git a/test/runtests.jl b/test/runtests.jl index 0db8447..837b9d4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,15 +3,16 @@ using Test include("definitions/arraypartition.jl") include("definitions/compose.jl") -include("utilities/iteration_tools.jl") -include("utilities/conjugate.jl") -include("utilities/fb_tools.jl") +include("utilities/test_ad.jl") +include("utilities/test_iteration_tools.jl") +include("utilities/test_conjugate.jl") +include("utilities/test_fb_tools.jl") -include("accel/lbfgs.jl") -include("accel/anderson.jl") -include("accel/nesterov.jl") -include("accel/broyden.jl") -include("accel/noaccel.jl") +include("accel/test_lbfgs.jl") +include("accel/test_anderson.jl") +include("accel/test_nesterov.jl") +include("accel/test_broyden.jl") +include("accel/test_noaccel.jl") include("problems/test_equivalence.jl") include("problems/test_elasticnet.jl") diff --git a/test/utilities/test_ad.jl b/test/utilities/test_ad.jl new file mode 100644 index 0000000..4aa0aef --- /dev/null +++ b/test/utilities/test_ad.jl @@ -0,0 +1,39 @@ +using Test +using LinearAlgebra +using ProximalOperators: NormL1 +using ProximalAlgorithms + +@testset "Autodiff ($T)" for T in [Float32, Float64, ComplexF32, ComplexF64] + R = real(T) + A = T[ + 1.0 -2.0 3.0 -4.0 5.0 + 2.0 -1.0 0.0 -1.0 3.0 + -1.0 0.0 4.0 -3.0 2.0 + -1.0 -1.0 -1.0 1.0 3.0 + ] + b = T[1.0, 2.0, 3.0, 4.0] + f(x) = R(1/2) * norm(A * x - b, 2)^2 + Lf = opnorm(A)^2 + m, n = size(A) + + @testset "Gradient" begin + x = randn(T, n) + gradfx, fx = gradient(f, x) + @test eltype(gradfx) == T + @test typeof(fx) == R + @test gradfx ≈ A' * (A * x - b) + end + + @testset "Algorithms" begin + lam = R(0.1) * norm(A' * b, Inf) + @test typeof(lam) == R + g = NormL1(lam) + x_star = T[-3.877278911564627e-01, 0, 0, 2.174149659863943e-02, 6.168435374149660e-01] + TOL = R(1e-4) + solver = ProximalAlgorithms.FastForwardBackward(tol = TOL) + x, it = solver(zeros(T, n), f = f, g = g, Lf = Lf) + @test eltype(x) == T + @test norm(x - x_star, Inf) <= TOL + @test it < 100 + end +end diff --git a/test/utilities/conjugate.jl b/test/utilities/test_conjugate.jl similarity index 98% rename from test/utilities/conjugate.jl rename to test/utilities/test_conjugate.jl index e025209..1bce562 100644 --- a/test/utilities/conjugate.jl +++ b/test/utilities/test_conjugate.jl @@ -1,3 +1,5 @@ +using Test + @testset "Conjugate" begin using ProximalOperators diff --git a/test/utilities/fb_tools.jl b/test/utilities/test_fb_tools.jl similarity index 100% rename from test/utilities/fb_tools.jl rename to test/utilities/test_fb_tools.jl diff --git a/test/utilities/iteration_tools.jl b/test/utilities/test_iteration_tools.jl similarity index 100% rename from test/utilities/iteration_tools.jl rename to test/utilities/test_iteration_tools.jl