diff --git a/test/differentiation.jl b/test/differentiation.jl index 5bc49d8c67..e3e1999325 100644 --- a/test/differentiation.jl +++ b/test/differentiation.jl @@ -91,6 +91,42 @@ using LinearAlgebra: Diagonal, dot set_default_differential_backend!(fd51) end + using ReverseDiff + + reverse_diff = Manifolds.ReverseDiffBackend() + @testset "ReverseDiff" begin + @test diff_backend() isa Manifolds.FiniteDifferencesBackend + @test length(diff_backends()) == 4 + @test diff_backends()[3] == reverse_diff + + @test diff_backend!(reverse_diff) == reverse_diff + @test diff_backend() == reverse_diff + @test diff_backend!(fd51) isa Manifolds.FiniteDifferencesBackend + @test diff_backend() isa Manifolds.FiniteDifferencesBackend + + diff_backend!(reverse_diff) + @test diff_backend() == reverse_diff + diff_backend!(fd51) + end + + using Zygote: Zygote + + zygote_diff = Manifolds.ZygoteDiffBackend() + @testset "Zygote" begin + @test diff_backend() isa Manifolds.FiniteDifferencesBackend + @test length(diff_backends()) == 5 + @test diff_backends()[5] == zygote_diff + + @test diff_backend!(zygote_diff) == zygote_diff + @test diff_backend() == zygote_diff + @test diff_backend!(fd51) isa Manifolds.FiniteDifferencesBackend + @test diff_backend() isa Manifolds.FiniteDifferencesBackend + + diff_backend!(zygote_diff) + @test diff_backend() == zygote_diff + diff_backend!(fd51) + end + @testset "gradient" begin set_default_differential_backend!(fd51) r2 = Euclidean(2)