From 7ed57dac28cecd8f662da1bb85180751d90118fa Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Mon, 20 Sep 2021 12:48:09 +0200 Subject: [PATCH 01/10] Support for Zygote and ReverseDiff gradients --- Project.toml | 2 +- src/Manifolds.jl | 19 ++++-- src/{ => differentiation}/differentiation.jl | 0 src/{ => differentiation}/finite_diff.jl | 0 src/{ => differentiation}/forward_diff.jl | 0 src/differentiation/reverse_diff.jl | 11 ++++ src/{ => differentiation}/riemannian_diff.jl | 0 src/differentiation/zygote.jl | 11 ++++ test/differentiation.jl | 63 ++++++++++++++++---- 9 files changed, 89 insertions(+), 17 deletions(-) rename src/{ => differentiation}/differentiation.jl (100%) rename src/{ => differentiation}/finite_diff.jl (100%) rename src/{ => differentiation}/forward_diff.jl (100%) create mode 100644 src/differentiation/reverse_diff.jl rename src/{ => differentiation}/riemannian_diff.jl (100%) create mode 100644 src/differentiation/zygote.jl diff --git a/Project.toml b/Project.toml index f44a2fbab4..46cfe7b91b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Manifolds" uuid = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e" authors = ["Seth Axen ", "Mateusz Baran ", "Ronny Bergmann ", "Antoine Levitt "] -version = "0.6.7" +version = "0.6.8" [deps] Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" diff --git a/src/Manifolds.jl b/src/Manifolds.jl index ddf474c2ef..056dfa9a22 100644 --- a/src/Manifolds.jl +++ b/src/Manifolds.jl @@ -158,8 +158,8 @@ using RecursiveArrayTools: ArrayPartition include("utils.jl") include("product_representations.jl") -include("differentiation.jl") -include("riemannian_diff.jl") +include("differentiation/differentiation.jl") +include("differentiation/riemannian_diff.jl") # Main Meta Manifolds include("manifolds/ConnectionManifold.jl") @@ -285,12 +285,12 @@ end function __init__() @require FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" begin using .FiniteDiff - include("finite_diff.jl") + include("differentiation/finite_diff.jl") end @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" begin using .ForwardDiff - include("forward_diff.jl") + include("differentiation/forward_diff.jl") end @require OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" begin @@ -303,6 +303,11 @@ function __init__() include("nlsolve.jl") end + @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin + using .ReverseDiff: ReverseDiff + include("differentiation/reverse_diff.jl") + end + @require Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" begin using .Test: Test include("tests/tests_general.jl") @@ -333,6 +338,12 @@ function __init__() include("recipes.jl") end end + + @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin + using .Zygote: Zygote + include("differentiation/zygote.jl") + end + return nothing end diff --git a/src/differentiation.jl b/src/differentiation/differentiation.jl similarity index 100% rename from src/differentiation.jl rename to src/differentiation/differentiation.jl diff --git a/src/finite_diff.jl b/src/differentiation/finite_diff.jl similarity index 100% rename from src/finite_diff.jl rename to src/differentiation/finite_diff.jl diff --git a/src/forward_diff.jl b/src/differentiation/forward_diff.jl similarity index 100% rename from src/forward_diff.jl rename to src/differentiation/forward_diff.jl diff --git a/src/differentiation/reverse_diff.jl b/src/differentiation/reverse_diff.jl new file mode 100644 index 0000000000..851bf89a2d --- /dev/null +++ b/src/differentiation/reverse_diff.jl @@ -0,0 +1,11 @@ +struct ReverseDiffBackend <: AbstractDiffBackend end + +function Manifolds._gradient(f, p, ::ReverseDiffBackend) + return ReverseDiff.gradient(f, p) +end + +function Manifolds._gradient!(f, X, p, ::ReverseDiffBackend) + return ReverseDiff.gradient!(X, f, p) +end + +push!(Manifolds._diff_backends, ReverseDiffBackend()) diff --git a/src/riemannian_diff.jl b/src/differentiation/riemannian_diff.jl similarity index 100% rename from src/riemannian_diff.jl rename to src/differentiation/riemannian_diff.jl diff --git a/src/differentiation/zygote.jl b/src/differentiation/zygote.jl new file mode 100644 index 0000000000..f8a2e70d16 --- /dev/null +++ b/src/differentiation/zygote.jl @@ -0,0 +1,11 @@ +struct ZygoteDiffBackend <: AbstractDiffBackend end + +function Manifolds._gradient(f, p, ::ZygoteDiffBackend) + return Zygote.gradient(f, p) +end + +function Manifolds._gradient!(f, X, p, ::ZygoteDiffBackend) + return Zygote.gradient!(X, f, p) +end + +push!(Manifolds._diff_backends, ZygoteDiffBackend()) diff --git a/test/differentiation.jl b/test/differentiation.jl index f6e2bd2992..b060d67962 100644 --- a/test/differentiation.jl +++ b/test/differentiation.jl @@ -17,7 +17,7 @@ using LinearAlgebra: Diagonal, dot fd51 = Manifolds.FiniteDifferencesBackend() @testset "diff_backend" begin @test diff_backend() isa Manifolds.FiniteDifferencesBackend - @test length(diff_backends()) == 2 + @test length(diff_backends()) == 3 @test diff_backends()[1] isa Manifolds.FiniteDifferencesBackend @test length(fd51.method.grid) == 5 @@ -33,7 +33,7 @@ using LinearAlgebra: Diagonal, dot fwd_diff = Manifolds.ForwardDiffBackend() @testset "ForwardDiff" begin @test diff_backend() isa Manifolds.FiniteDifferencesBackend - @test length(diff_backends()) == 2 + @test length(diff_backends()) == 3 @test diff_backends()[1] isa Manifolds.FiniteDifferencesBackend @test diff_backends()[2] == fwd_diff @@ -52,8 +52,8 @@ using LinearAlgebra: Diagonal, dot finite_diff = Manifolds.FiniteDiffBackend() @testset "FiniteDiff" begin @test diff_backend() isa Manifolds.FiniteDifferencesBackend - @test length(diff_backends()) == 3 - @test diff_backends()[3] == finite_diff + @test length(diff_backends()) == 4 + @test diff_backends()[4] == finite_diff @test diff_backend!(finite_diff) == finite_diff @test diff_backend() == finite_diff @@ -65,6 +65,42 @@ using LinearAlgebra: Diagonal, dot diff_backend!(fd51) end + using ReverseDiff + + reverse_diff = Manifolds.ReverseDiffBackend() + @testset "ReverseDiff" begin + @test diff_backend() isa Manifolds.FiniteDifferencesBackend + @test length(diff_backends()) == 4 + @test diff_backends()[3] == reverse_diff + + @test diff_backend!(reverse_diff) == reverse_diff + @test diff_backend() == reverse_diff + @test diff_backend!(fd51) isa Manifolds.FiniteDifferencesBackend + @test diff_backend() isa Manifolds.FiniteDifferencesBackend + + diff_backend!(reverse_diff) + @test diff_backend() == reverse_diff + diff_backend!(fd51) + end + + using Zygote: Zygote + + zygote_diff = Manifolds.ZygoteDiffBackend() + @testset "Zygote" begin + @test diff_backend() isa Manifolds.FiniteDifferencesBackend + @test length(diff_backends()) == 5 + @test diff_backends()[5] == zygote_diff + + @test diff_backend!(zygote_diff) == zygote_diff + @test diff_backend() == zygote_diff + @test diff_backend!(fd51) isa Manifolds.FiniteDifferencesBackend + @test diff_backend() isa Manifolds.FiniteDifferencesBackend + + diff_backend!(zygote_diff) + @test diff_backend() == zygote_diff + diff_backend!(fd51) + end + @testset "gradient" begin diff_backend!(fd51) r2 = Euclidean(2) @@ -74,11 +110,11 @@ using LinearAlgebra: Diagonal, dot f2(x) = 3 * x[1] * x[2] + x[2]^3 @testset "Inference" begin - v = [-1.0, -1.0] + X = [-1.0, -1.0] @test (@inferred _derivative(c1, 0.0, Manifolds.ForwardDiffBackend())) ≈ [1.0, 0.0] - @test (@inferred _derivative!(c1, v, 0.0, Manifolds.ForwardDiffBackend())) === v - @test v ≈ [1.0, 0.0] + @test (@inferred _derivative!(c1, X, 0.0, Manifolds.ForwardDiffBackend())) === X + @test X ≈ [1.0, 0.0] @test (@inferred _derivative(c1, 0.0, finite_diff)) ≈ [1.0, 0.0] @test (@inferred _gradient(f1, [1.0, -1.0], finite_diff)) ≈ [1.0, -2.0] @@ -87,12 +123,15 @@ using LinearAlgebra: Diagonal, dot @testset for backend in [fd51, fwd_diff, finite_diff] diff_backend!(backend) @test _derivative(c1, 0.0) ≈ [1.0, 0.0] - v = [-1.0, -1.0] - @test _derivative!(c1, v, 0.0) === v - @test isapprox(v, [1.0, 0.0]) + X = [-1.0, -1.0] + @test _derivative!(c1, X, 0.0) === X + @test isapprox(X, [1.0, 0.0]) + end + @testset for backend in [fd51, fwd_diff, finite_diff, reverse_diff, zygote_diff] + X = [-1.0, -1.0] @test _gradient(f1, [1.0, -1.0]) ≈ [1.0, -2.0] - @test _gradient!(f1, v, [1.0, -1.0]) === v - @test v ≈ [1.0, -2.0] + @test _gradient!(f1, X, [1.0, -1.0]) === X + @test X ≈ [1.0, -2.0] end diff_backend!(Manifolds.NoneDiffBackend()) @testset for backend in [fd51, Manifolds.ForwardDiffBackend()] From 91634807265d1da5fe38b3486122bcc1d1b7b043 Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Mon, 20 Sep 2021 14:40:50 +0200 Subject: [PATCH 02/10] Add Zygote test dependency --- Project.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 46cfe7b91b..de91344da7 100644 --- a/Project.toml +++ b/Project.toml @@ -62,6 +62,7 @@ RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" VisualRegressionTests = "34922c18-7c2a-561c-bac1-01e79b2c4c92" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Test", "Colors", "DoubleFloats", "FiniteDiff", "ForwardDiff", "Gtk", "ImageIO", "ImageMagick", "OrdinaryDiffEq", "NLsolve", "Plots", "PyPlot", "Quaternions", "QuartzImageIO", "RecipesBase", "ReverseDiff"] +test = ["Test", "Colors", "DoubleFloats", "FiniteDiff", "ForwardDiff", "Gtk", "ImageIO", "ImageMagick", "OrdinaryDiffEq", "NLsolve", "Plots", "PyPlot", "Quaternions", "QuartzImageIO", "RecipesBase", "ReverseDiff", "Zygote"] From 7849fc0a3057c590df189f184bc2ac3f6fa67cde Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Mon, 20 Sep 2021 19:45:41 +0200 Subject: [PATCH 03/10] bump ambiguity limit because of Zygote --- test/ambiguities.jl | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/test/ambiguities.jl b/test/ambiguities.jl index 418d801665..57988178d2 100644 --- a/test/ambiguities.jl +++ b/test/ambiguities.jl @@ -4,12 +4,26 @@ # Interims solution until we follow what was proposed in # https://discourse.julialang.org/t/avoid-ambiguities-with-individual-number-element-identity/62465/2 fmbs = filter(x -> !any(has_type_in_signature.(x, Identity)), mbs) - @test length(fmbs) <= 20 + FMBS_LIMIT = 20 + @test length(fmbs) <= FMBS_LIMIT + if length(fmbs) > FMBS_LIMIT + for amb in fmbs + println(amb) + println() + end + end ms = Test.detect_ambiguities(Manifolds) # Interims solution until we follow what was proposed in # https://discourse.julialang.org/t/avoid-ambiguities-with-individual-number-element-identity/62465/2 fms = filter(x -> !any(has_type_in_signature.(x, Identity)), ms) - @test length(fms) <= 17 + FMS_LIMIT = 21 + if length(fms) > FMS_LIMIT + for amb in fms + println(amb) + println() + end + end + @test length(fms) <= FMS_LIMIT # this test takes way too long to perform regularly # @test length(our_base_ambiguities()) <= 4 else From 1ed46db14f8152c3235cfd40e789703ca57ff032 Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Mon, 20 Sep 2021 23:02:03 +0200 Subject: [PATCH 04/10] fix tests and Zygote backend --- src/differentiation/zygote.jl | 4 ++-- test/differentiation.jl | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/differentiation/zygote.jl b/src/differentiation/zygote.jl index f8a2e70d16..ffe180113f 100644 --- a/src/differentiation/zygote.jl +++ b/src/differentiation/zygote.jl @@ -1,11 +1,11 @@ struct ZygoteDiffBackend <: AbstractDiffBackend end function Manifolds._gradient(f, p, ::ZygoteDiffBackend) - return Zygote.gradient(f, p) + return Zygote.gradient(f, p)[1] end function Manifolds._gradient!(f, X, p, ::ZygoteDiffBackend) - return Zygote.gradient!(X, f, p) + return copyto!(X, Zygote.gradient(f, p)[1]) end push!(Manifolds._diff_backends, ZygoteDiffBackend()) diff --git a/test/differentiation.jl b/test/differentiation.jl index b060d67962..cc0d2361ed 100644 --- a/test/differentiation.jl +++ b/test/differentiation.jl @@ -128,6 +128,7 @@ using LinearAlgebra: Diagonal, dot @test isapprox(X, [1.0, 0.0]) end @testset for backend in [fd51, fwd_diff, finite_diff, reverse_diff, zygote_diff] + diff_backend!(backend) X = [-1.0, -1.0] @test _gradient(f1, [1.0, -1.0]) ≈ [1.0, -2.0] @test _gradient!(f1, X, [1.0, -1.0]) === X From cf00e220a9d1c1c73b9b4f255e5c42be27b5ba7b Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Tue, 21 Sep 2021 08:49:40 +0200 Subject: [PATCH 05/10] bump Julia to 1.5 --- .github/workflows/ci.yml | 2 +- Project.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index efd106f843..2e3a0a6108 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -11,7 +11,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - julia-version: ["1.4", "1.6"] + julia-version: ["1.5", "1.6", "~1.7.0-0"] os: [ubuntu-latest, macOS-latest] steps: - uses: actions/checkout@v2 diff --git a/Project.toml b/Project.toml index de91344da7..4d06e4f596 100644 --- a/Project.toml +++ b/Project.toml @@ -42,7 +42,7 @@ SimpleWeightedGraphs = "1" SpecialFunctions = "0.8, 0.9, 0.10, 1.0" StaticArrays = "1.0" StatsBase = "0.32, 0.33" -julia = "1.4" +julia = "1.5" [extras] Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" From c723345258687723479167c18680a683b29e4c5e Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Tue, 21 Sep 2021 12:31:30 +0200 Subject: [PATCH 06/10] fixing some issues on Julia 1.7-rc1 --- src/nlsolve.jl | 1 + test/groups/special_euclidean.jl | 1 + test/manifolds/rotations.jl | 2 +- 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/nlsolve.jl b/src/nlsolve.jl index 957bac52a5..e9cd460f9d 100644 --- a/src/nlsolve.jl +++ b/src/nlsolve.jl @@ -55,6 +55,7 @@ function _inverse_retract_nlsolve( retract!(M, F, p, project(M, p, X), retraction; kwargs...) project_point && project!(M, q, q) F .-= q + project_point && project!(M, F, F) return F end isdefined(Manifolds, :NLsolve) || diff --git a/test/groups/special_euclidean.jl b/test/groups/special_euclidean.jl index eedbab4575..ca8730ca97 100644 --- a/test/groups/special_euclidean.jl +++ b/test/groups/special_euclidean.jl @@ -78,6 +78,7 @@ Random.seed!(10) X_pts; test_diff=true, diff_convs=[(), (LeftAction(),), (RightAction(),)], + atol=1e-9, ) end end diff --git a/test/manifolds/rotations.jl b/test/manifolds/rotations.jl index 66701f17b0..c9055c82c4 100644 --- a/test/manifolds/rotations.jl +++ b/test/manifolds/rotations.jl @@ -116,7 +116,7 @@ include("../utils.jl") point_distributions=[ptd], tvector_distributions=[tvd], basis_types_to_from=basis_types, - exp_log_atol_multiplier=20, + exp_log_atol_multiplier=80, retraction_atol_multiplier=12, test_inplace=true, ) From 5d75e6ebc6b05a63c7c54937bf478e41bdecc651 Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Tue, 21 Sep 2021 19:17:28 +0200 Subject: [PATCH 07/10] more fixing for Julia 1.7 --- src/nlsolve.jl | 1 - test/approx_inverse_retraction.jl | 4 +++- test/manifolds/power_manifold.jl | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/nlsolve.jl b/src/nlsolve.jl index e9cd460f9d..957bac52a5 100644 --- a/src/nlsolve.jl +++ b/src/nlsolve.jl @@ -55,7 +55,6 @@ function _inverse_retract_nlsolve( retract!(M, F, p, project(M, p, X), retraction; kwargs...) project_point && project!(M, q, q) F .-= q - project_point && project!(M, F, F) return F end isdefined(Manifolds, :NLsolve) || diff --git a/test/approx_inverse_retraction.jl b/test/approx_inverse_retraction.jl index 00c41360e5..5f305a1d4b 100644 --- a/test/approx_inverse_retraction.jl +++ b/test/approx_inverse_retraction.jl @@ -3,6 +3,8 @@ using LinearAlgebra include("utils.jl") +Random.seed!(10) + @testset "approximate inverse retractions" begin @testset "NLsolveInverseRetraction" begin @testset "constructor" begin @@ -62,7 +64,7 @@ include("utils.jl") NLsolveInverseRetraction(ProjectionRetraction(), X0; project_point=true) X = inverse_retract(M, p, q, inv_retr_method) @test is_vector(M, p, X; atol=1e-9) - @test X ≈ X_exp + @test X ≈ X_exp atol=1e-8 @test_throws OutOfInjectivityRadiusError inverse_retract( M, p, diff --git a/test/manifolds/power_manifold.jl b/test/manifolds/power_manifold.jl index bc7d03b95b..4a58fcabfc 100644 --- a/test/manifolds/power_manifold.jl +++ b/test/manifolds/power_manifold.jl @@ -240,7 +240,7 @@ end rand_tvector_atol_multiplier=6.0, retraction_atol_multiplier=12, is_tangent_atol_multiplier=12.0, - exp_log_atol_multiplier=2 * prod(power_dimensions(Ms2)), + exp_log_atol_multiplier=3 * prod(power_dimensions(Ms2)), test_inplace=true, ) end From 3b7396d3329073a31328220f87833960f700f0bc Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Tue, 21 Sep 2021 19:19:09 +0200 Subject: [PATCH 08/10] formatting --- test/approx_inverse_retraction.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/approx_inverse_retraction.jl b/test/approx_inverse_retraction.jl index 5f305a1d4b..83fa282874 100644 --- a/test/approx_inverse_retraction.jl +++ b/test/approx_inverse_retraction.jl @@ -64,7 +64,7 @@ Random.seed!(10) NLsolveInverseRetraction(ProjectionRetraction(), X0; project_point=true) X = inverse_retract(M, p, q, inv_retr_method) @test is_vector(M, p, X; atol=1e-9) - @test X ≈ X_exp atol=1e-8 + @test X ≈ X_exp atol = 1e-8 @test_throws OutOfInjectivityRadiusError inverse_retract( M, p, From 3c47350fd3ce87550757b81020aeb00b6dc5b7d7 Mon Sep 17 00:00:00 2001 From: Ronny Bergmann Date: Wed, 22 Sep 2021 08:59:28 +0200 Subject: [PATCH 09/10] reduce tangent vector length in a test since the approximation only works very locally (and they changed the default random number generator) --- test/manifolds/stiefel.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/manifolds/stiefel.jl b/test/manifolds/stiefel.jl index e9d89ee7c6..9d494bbb4c 100644 --- a/test/manifolds/stiefel.jl +++ b/test/manifolds/stiefel.jl @@ -311,7 +311,7 @@ using Manifolds: default_metric_dispatch M4 = MetricManifold(Stiefel(10, 2), CanonicalMetric()) p = Matrix{Float64}(I, 10, 2) Random.seed!(42) - Z = project(base_manifold(M4), p, randn(size(p))) + Z = project(base_manifold(M4), p, 0.5 .* randn(size(p))) s = exp(M4, p, Z) Z2 = log(M4, p, s) @test isapprox(M4, p, Z, Z2) From 07c905eab7de9aa0ca827ce6977ce7c5fc4ae980 Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Wed, 22 Sep 2021 17:27:10 +0200 Subject: [PATCH 10/10] Update Project.toml Co-authored-by: Ronny Bergmann --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 4d06e4f596..e91fcf301f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Manifolds" uuid = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e" authors = ["Seth Axen ", "Mateusz Baran ", "Ronny Bergmann ", "Antoine Levitt "] -version = "0.6.8" +version = "0.6.9" [deps] Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"