Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for Zygote and ReverseDiff gradients #427

Merged
merged 16 commits into from
Sep 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
julia-version: ["1.4", "1.6"]
julia-version: ["1.5", "1.6", "~1.7.0-0"]
os: [ubuntu-latest, macOS-latest]
steps:
- uses: actions/checkout@v2
Expand Down
7 changes: 4 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Manifolds"
uuid = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e"
authors = ["Seth Axen <[email protected]>", "Mateusz Baran <[email protected]>", "Ronny Bergmann <[email protected]>", "Antoine Levitt <[email protected]>"]
version = "0.6.8"
version = "0.6.9"

[deps]
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
Expand Down Expand Up @@ -42,7 +42,7 @@ SimpleWeightedGraphs = "1"
SpecialFunctions = "0.8, 0.9, 0.10, 1.0"
StaticArrays = "1.0"
StatsBase = "0.32, 0.33"
julia = "1.4"
julia = "1.5"

[extras]
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
Expand All @@ -62,6 +62,7 @@ RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
VisualRegressionTests = "34922c18-7c2a-561c-bac1-01e79b2c4c92"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "Colors", "DoubleFloats", "FiniteDiff", "ForwardDiff", "Gtk", "ImageIO", "ImageMagick", "OrdinaryDiffEq", "NLsolve", "Plots", "PyPlot", "Quaternions", "QuartzImageIO", "RecipesBase", "ReverseDiff"]
test = ["Test", "Colors", "DoubleFloats", "FiniteDiff", "ForwardDiff", "Gtk", "ImageIO", "ImageMagick", "OrdinaryDiffEq", "NLsolve", "Plots", "PyPlot", "Quaternions", "QuartzImageIO", "RecipesBase", "ReverseDiff", "Zygote"]
19 changes: 15 additions & 4 deletions src/Manifolds.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,8 @@ using RecursiveArrayTools: ArrayPartition
include("utils.jl")

include("product_representations.jl")
include("differentiation.jl")
include("riemannian_diff.jl")
include("differentiation/differentiation.jl")
include("differentiation/riemannian_diff.jl")

# Main Meta Manifolds
include("manifolds/ConnectionManifold.jl")
Expand Down Expand Up @@ -284,12 +284,12 @@ end
function __init__()
@require FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" begin
using .FiniteDiff
include("finite_diff.jl")
include("differentiation/finite_diff.jl")
end

@require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" begin
using .ForwardDiff
include("forward_diff.jl")
include("differentiation/forward_diff.jl")
end

@require OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" begin
Expand All @@ -302,6 +302,11 @@ function __init__()
include("nlsolve.jl")
end

@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
using .ReverseDiff: ReverseDiff
include("differentiation/reverse_diff.jl")
end

@require Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" begin
using .Test: Test
include("tests/tests_general.jl")
Expand Down Expand Up @@ -332,6 +337,12 @@ function __init__()
include("recipes.jl")
end
end

@require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
using .Zygote: Zygote
include("differentiation/zygote.jl")
end

return nothing
end

Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
11 changes: 11 additions & 0 deletions src/differentiation/reverse_diff.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
struct ReverseDiffBackend <: AbstractDiffBackend end

function Manifolds._gradient(f, p, ::ReverseDiffBackend)
return ReverseDiff.gradient(f, p)
end

function Manifolds._gradient!(f, X, p, ::ReverseDiffBackend)
return ReverseDiff.gradient!(X, f, p)
end

push!(Manifolds._diff_backends, ReverseDiffBackend())
File renamed without changes.
11 changes: 11 additions & 0 deletions src/differentiation/zygote.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
struct ZygoteDiffBackend <: AbstractDiffBackend end

function Manifolds._gradient(f, p, ::ZygoteDiffBackend)
return Zygote.gradient(f, p)[1]
end

function Manifolds._gradient!(f, X, p, ::ZygoteDiffBackend)
return copyto!(X, Zygote.gradient(f, p)[1])
end

push!(Manifolds._diff_backends, ZygoteDiffBackend())
18 changes: 16 additions & 2 deletions test/ambiguities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,26 @@
# Interims solution until we follow what was proposed in
# https://discourse.julialang.org/t/avoid-ambiguities-with-individual-number-element-identity/62465/2
fmbs = filter(x -> !any(has_type_in_signature.(x, Identity)), mbs)
@test length(fmbs) <= 20
FMBS_LIMIT = 20
@test length(fmbs) <= FMBS_LIMIT
if length(fmbs) > FMBS_LIMIT
for amb in fmbs
println(amb)
println()
end
end
ms = Test.detect_ambiguities(Manifolds)
# Interims solution until we follow what was proposed in
# https://discourse.julialang.org/t/avoid-ambiguities-with-individual-number-element-identity/62465/2
fms = filter(x -> !any(has_type_in_signature.(x, Identity)), ms)
@test length(fms) <= 17
FMS_LIMIT = 21
if length(fms) > FMS_LIMIT
for amb in fms
println(amb)
println()
end
end
@test length(fms) <= FMS_LIMIT
# this test takes way too long to perform regularly
# @test length(our_base_ambiguities()) <= 4
else
Expand Down
4 changes: 3 additions & 1 deletion test/approx_inverse_retraction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ using LinearAlgebra

include("utils.jl")

Random.seed!(10)

@testset "approximate inverse retractions" begin
@testset "NLsolveInverseRetraction" begin
@testset "constructor" begin
Expand Down Expand Up @@ -62,7 +64,7 @@ include("utils.jl")
NLsolveInverseRetraction(ProjectionRetraction(), X0; project_point=true)
X = inverse_retract(M, p, q, inv_retr_method)
@test is_vector(M, p, X; atol=1e-9)
@test X ≈ X_exp
@test X ≈ X_exp atol = 1e-8
@test_throws OutOfInjectivityRadiusError inverse_retract(
M,
p,
Expand Down
64 changes: 52 additions & 12 deletions test/differentiation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ using LinearAlgebra: Diagonal, dot
fd51 = Manifolds.FiniteDifferencesBackend()
@testset "diff_backend" begin
@test diff_backend() isa Manifolds.FiniteDifferencesBackend
@test length(diff_backends()) == 2
@test length(diff_backends()) == 3
@test diff_backends()[1] isa Manifolds.FiniteDifferencesBackend

@test length(fd51.method.grid) == 5
Expand All @@ -33,7 +33,7 @@ using LinearAlgebra: Diagonal, dot
fwd_diff = Manifolds.ForwardDiffBackend()
@testset "ForwardDiff" begin
@test diff_backend() isa Manifolds.FiniteDifferencesBackend
@test length(diff_backends()) == 2
@test length(diff_backends()) == 3
@test diff_backends()[1] isa Manifolds.FiniteDifferencesBackend
@test diff_backends()[2] == fwd_diff

Expand All @@ -52,8 +52,8 @@ using LinearAlgebra: Diagonal, dot
finite_diff = Manifolds.FiniteDiffBackend()
@testset "FiniteDiff" begin
@test diff_backend() isa Manifolds.FiniteDifferencesBackend
@test length(diff_backends()) == 3
@test diff_backends()[3] == finite_diff
@test length(diff_backends()) == 4
@test diff_backends()[4] == finite_diff

@test diff_backend!(finite_diff) == finite_diff
@test diff_backend() == finite_diff
Expand All @@ -65,6 +65,42 @@ using LinearAlgebra: Diagonal, dot
diff_backend!(fd51)
end

using ReverseDiff

reverse_diff = Manifolds.ReverseDiffBackend()
@testset "ReverseDiff" begin
@test diff_backend() isa Manifolds.FiniteDifferencesBackend
@test length(diff_backends()) == 4
@test diff_backends()[3] == reverse_diff

@test diff_backend!(reverse_diff) == reverse_diff
@test diff_backend() == reverse_diff
@test diff_backend!(fd51) isa Manifolds.FiniteDifferencesBackend
@test diff_backend() isa Manifolds.FiniteDifferencesBackend

diff_backend!(reverse_diff)
@test diff_backend() == reverse_diff
diff_backend!(fd51)
end

using Zygote: Zygote

zygote_diff = Manifolds.ZygoteDiffBackend()
@testset "Zygote" begin
@test diff_backend() isa Manifolds.FiniteDifferencesBackend
@test length(diff_backends()) == 5
@test diff_backends()[5] == zygote_diff

@test diff_backend!(zygote_diff) == zygote_diff
@test diff_backend() == zygote_diff
@test diff_backend!(fd51) isa Manifolds.FiniteDifferencesBackend
@test diff_backend() isa Manifolds.FiniteDifferencesBackend

diff_backend!(zygote_diff)
@test diff_backend() == zygote_diff
diff_backend!(fd51)
end

@testset "gradient" begin
diff_backend!(fd51)
r2 = Euclidean(2)
Expand All @@ -74,11 +110,11 @@ using LinearAlgebra: Diagonal, dot
f2(x) = 3 * x[1] * x[2] + x[2]^3

@testset "Inference" begin
v = [-1.0, -1.0]
X = [-1.0, -1.0]
@test (@inferred _derivative(c1, 0.0, Manifolds.ForwardDiffBackend())) ≈
[1.0, 0.0]
@test (@inferred _derivative!(c1, v, 0.0, Manifolds.ForwardDiffBackend())) === v
@test v ≈ [1.0, 0.0]
@test (@inferred _derivative!(c1, X, 0.0, Manifolds.ForwardDiffBackend())) === X
@test X ≈ [1.0, 0.0]

@test (@inferred _derivative(c1, 0.0, finite_diff)) ≈ [1.0, 0.0]
@test (@inferred _gradient(f1, [1.0, -1.0], finite_diff)) ≈ [1.0, -2.0]
Expand All @@ -87,12 +123,16 @@ using LinearAlgebra: Diagonal, dot
@testset for backend in [fd51, fwd_diff, finite_diff]
diff_backend!(backend)
@test _derivative(c1, 0.0) ≈ [1.0, 0.0]
v = [-1.0, -1.0]
@test _derivative!(c1, v, 0.0) === v
@test isapprox(v, [1.0, 0.0])
X = [-1.0, -1.0]
@test _derivative!(c1, X, 0.0) === X
@test isapprox(X, [1.0, 0.0])
end
@testset for backend in [fd51, fwd_diff, finite_diff, reverse_diff, zygote_diff]
diff_backend!(backend)
X = [-1.0, -1.0]
@test _gradient(f1, [1.0, -1.0]) ≈ [1.0, -2.0]
@test _gradient!(f1, v, [1.0, -1.0]) === v
@test v ≈ [1.0, -2.0]
@test _gradient!(f1, X, [1.0, -1.0]) === X
@test X ≈ [1.0, -2.0]
end
diff_backend!(Manifolds.NoneDiffBackend())
@testset for backend in [fd51, Manifolds.ForwardDiffBackend()]
Expand Down
2 changes: 2 additions & 0 deletions test/groups/special_euclidean.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ Random.seed!(10)
X_pts;
test_diff=true,
diff_convs=[(), (LeftAction(),), (RightAction(),)],
atol=1e-9,
)
end
end
Expand Down Expand Up @@ -128,6 +129,7 @@ Random.seed!(10)
test_diff=true,
test_lie_bracket=true,
diff_convs=[(), (LeftAction(),), (RightAction(),)],
atol=1e-9,
)
# specific affine tests
p = copy(G, pts[1])
Expand Down
2 changes: 1 addition & 1 deletion test/manifolds/power_manifold.jl
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ end
rand_tvector_atol_multiplier=6.0,
retraction_atol_multiplier=12,
is_tangent_atol_multiplier=12.0,
exp_log_atol_multiplier=2 * prod(power_dimensions(Ms2)),
exp_log_atol_multiplier=3 * prod(power_dimensions(Ms2)),
test_inplace=true,
)
end
Expand Down
2 changes: 1 addition & 1 deletion test/manifolds/rotations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ include("../utils.jl")
point_distributions=[ptd],
tvector_distributions=[tvd],
basis_types_to_from=basis_types,
exp_log_atol_multiplier=20,
exp_log_atol_multiplier=250,
retraction_atol_multiplier=12,
test_inplace=true,
)
Expand Down
2 changes: 1 addition & 1 deletion test/manifolds/stiefel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ using Manifolds: default_metric_dispatch
M4 = MetricManifold(Stiefel(10, 2), CanonicalMetric())
p = Matrix{Float64}(I, 10, 2)
Random.seed!(42)
Z = project(base_manifold(M4), p, randn(size(p)))
Z = project(base_manifold(M4), p, 0.2 .* randn(size(p)))
s = exp(M4, p, Z)
Z2 = log(M4, p, s)
@test isapprox(M4, p, Z, Z2)
Expand Down