Skip to content

Commit

Permalink
Support for Zygote and ReverseDiff gradients
Browse files Browse the repository at this point in the history
  • Loading branch information
mateuszbaran committed Sep 20, 2021
1 parent bd2b425 commit 7ed57da
Show file tree
Hide file tree
Showing 9 changed files with 89 additions and 17 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Manifolds"
uuid = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e"
authors = ["Seth Axen <[email protected]>", "Mateusz Baran <[email protected]>", "Ronny Bergmann <[email protected]>", "Antoine Levitt <[email protected]>"]
version = "0.6.7"
version = "0.6.8"

[deps]
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
Expand Down
19 changes: 15 additions & 4 deletions src/Manifolds.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ using RecursiveArrayTools: ArrayPartition
include("utils.jl")

include("product_representations.jl")
include("differentiation.jl")
include("riemannian_diff.jl")
include("differentiation/differentiation.jl")
include("differentiation/riemannian_diff.jl")

# Main Meta Manifolds
include("manifolds/ConnectionManifold.jl")
Expand Down Expand Up @@ -285,12 +285,12 @@ end
function __init__()
@require FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" begin
using .FiniteDiff
include("finite_diff.jl")
include("differentiation/finite_diff.jl")
end

@require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" begin
using .ForwardDiff
include("forward_diff.jl")
include("differentiation/forward_diff.jl")
end

@require OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" begin
Expand All @@ -303,6 +303,11 @@ function __init__()
include("nlsolve.jl")
end

@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
using .ReverseDiff: ReverseDiff
include("differentiation/reverse_diff.jl")
end

@require Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" begin
using .Test: Test
include("tests/tests_general.jl")
Expand Down Expand Up @@ -333,6 +338,12 @@ function __init__()
include("recipes.jl")
end
end

@require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
using .Zygote: Zygote
include("differentiation/zygote.jl")
end

return nothing
end

Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
11 changes: 11 additions & 0 deletions src/differentiation/reverse_diff.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
struct ReverseDiffBackend <: AbstractDiffBackend end

function Manifolds._gradient(f, p, ::ReverseDiffBackend)
return ReverseDiff.gradient(f, p)
end

function Manifolds._gradient!(f, X, p, ::ReverseDiffBackend)
return ReverseDiff.gradient!(X, f, p)
end

push!(Manifolds._diff_backends, ReverseDiffBackend())
File renamed without changes.
11 changes: 11 additions & 0 deletions src/differentiation/zygote.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
struct ZygoteDiffBackend <: AbstractDiffBackend end

function Manifolds._gradient(f, p, ::ZygoteDiffBackend)
return Zygote.gradient(f, p)
end

function Manifolds._gradient!(f, X, p, ::ZygoteDiffBackend)
return Zygote.gradient!(X, f, p)
end

push!(Manifolds._diff_backends, ZygoteDiffBackend())
63 changes: 51 additions & 12 deletions test/differentiation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ using LinearAlgebra: Diagonal, dot
fd51 = Manifolds.FiniteDifferencesBackend()
@testset "diff_backend" begin
@test diff_backend() isa Manifolds.FiniteDifferencesBackend
@test length(diff_backends()) == 2
@test length(diff_backends()) == 3
@test diff_backends()[1] isa Manifolds.FiniteDifferencesBackend

@test length(fd51.method.grid) == 5
Expand All @@ -33,7 +33,7 @@ using LinearAlgebra: Diagonal, dot
fwd_diff = Manifolds.ForwardDiffBackend()
@testset "ForwardDiff" begin
@test diff_backend() isa Manifolds.FiniteDifferencesBackend
@test length(diff_backends()) == 2
@test length(diff_backends()) == 3
@test diff_backends()[1] isa Manifolds.FiniteDifferencesBackend
@test diff_backends()[2] == fwd_diff

Expand All @@ -52,8 +52,8 @@ using LinearAlgebra: Diagonal, dot
finite_diff = Manifolds.FiniteDiffBackend()
@testset "FiniteDiff" begin
@test diff_backend() isa Manifolds.FiniteDifferencesBackend
@test length(diff_backends()) == 3
@test diff_backends()[3] == finite_diff
@test length(diff_backends()) == 4
@test diff_backends()[4] == finite_diff

@test diff_backend!(finite_diff) == finite_diff
@test diff_backend() == finite_diff
Expand All @@ -65,6 +65,42 @@ using LinearAlgebra: Diagonal, dot
diff_backend!(fd51)
end

using ReverseDiff

reverse_diff = Manifolds.ReverseDiffBackend()
@testset "ReverseDiff" begin
@test diff_backend() isa Manifolds.FiniteDifferencesBackend
@test length(diff_backends()) == 4
@test diff_backends()[3] == reverse_diff

@test diff_backend!(reverse_diff) == reverse_diff
@test diff_backend() == reverse_diff
@test diff_backend!(fd51) isa Manifolds.FiniteDifferencesBackend
@test diff_backend() isa Manifolds.FiniteDifferencesBackend

diff_backend!(reverse_diff)
@test diff_backend() == reverse_diff
diff_backend!(fd51)
end

using Zygote: Zygote

zygote_diff = Manifolds.ZygoteDiffBackend()
@testset "Zygote" begin
@test diff_backend() isa Manifolds.FiniteDifferencesBackend
@test length(diff_backends()) == 5
@test diff_backends()[5] == zygote_diff

@test diff_backend!(zygote_diff) == zygote_diff
@test diff_backend() == zygote_diff
@test diff_backend!(fd51) isa Manifolds.FiniteDifferencesBackend
@test diff_backend() isa Manifolds.FiniteDifferencesBackend

diff_backend!(zygote_diff)
@test diff_backend() == zygote_diff
diff_backend!(fd51)
end

@testset "gradient" begin
diff_backend!(fd51)
r2 = Euclidean(2)
Expand All @@ -74,11 +110,11 @@ using LinearAlgebra: Diagonal, dot
f2(x) = 3 * x[1] * x[2] + x[2]^3

@testset "Inference" begin
v = [-1.0, -1.0]
X = [-1.0, -1.0]
@test (@inferred _derivative(c1, 0.0, Manifolds.ForwardDiffBackend()))
[1.0, 0.0]
@test (@inferred _derivative!(c1, v, 0.0, Manifolds.ForwardDiffBackend())) === v
@test v [1.0, 0.0]
@test (@inferred _derivative!(c1, X, 0.0, Manifolds.ForwardDiffBackend())) === X
@test X [1.0, 0.0]

@test (@inferred _derivative(c1, 0.0, finite_diff)) [1.0, 0.0]
@test (@inferred _gradient(f1, [1.0, -1.0], finite_diff)) [1.0, -2.0]
Expand All @@ -87,12 +123,15 @@ using LinearAlgebra: Diagonal, dot
@testset for backend in [fd51, fwd_diff, finite_diff]
diff_backend!(backend)
@test _derivative(c1, 0.0) [1.0, 0.0]
v = [-1.0, -1.0]
@test _derivative!(c1, v, 0.0) === v
@test isapprox(v, [1.0, 0.0])
X = [-1.0, -1.0]
@test _derivative!(c1, X, 0.0) === X
@test isapprox(X, [1.0, 0.0])
end
@testset for backend in [fd51, fwd_diff, finite_diff, reverse_diff, zygote_diff]
X = [-1.0, -1.0]
@test _gradient(f1, [1.0, -1.0]) [1.0, -2.0]
@test _gradient!(f1, v, [1.0, -1.0]) === v
@test v [1.0, -2.0]
@test _gradient!(f1, X, [1.0, -1.0]) === X
@test X [1.0, -2.0]
end
diff_backend!(Manifolds.NoneDiffBackend())
@testset for backend in [fd51, Manifolds.ForwardDiffBackend()]
Expand Down

0 comments on commit 7ed57da

Please sign in to comment.