From aca6d5c2cf66119873b1727c6d1241847602decc Mon Sep 17 00:00:00 2001 From: Ronny Bergmann Date: Wed, 22 Sep 2021 19:46:59 +0200 Subject: [PATCH] Add a proper working default. --- src/differentiation/riemannian_diff.jl | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/differentiation/riemannian_diff.jl b/src/differentiation/riemannian_diff.jl index d696daba6f..d24f02f96d 100644 --- a/src/differentiation/riemannian_diff.jl +++ b/src/differentiation/riemannian_diff.jl @@ -67,12 +67,12 @@ In the tangent space itself, this backend then employs an (Euclidean) # Constructor - TangentDiffBackend() + TangentDiffBackend(diff_backend) + +where `diff_backend` is an [`AbstractDiffBackend`](@ref) to be used on the tangent space. With the keyword arguments -* `diff_backend` an [`AbstractDiffBackend`](@ref) to be used on the tangent space - (by defaut [`FiniteDiffBackend`](@ref) is used) * `retraction` an [`AbstractRetractionMethod`](@ref) ([`ExponentialRetraction`]('ref) by default) * `inverse_retraction` an [`AbstractInverseRetractionMethod`](@ref) ([`LogarithmicInverseRetraction`]('ref) by default) * `basis` an [`AbstractBasis`](@ref) ([`DefaultOrthogonalBasis`]('ref) by default) @@ -88,8 +88,7 @@ struct TangentDiffBackend{ inverse_retraction::TIR basis::TB end -function TangentDiffBackend(; - diff_backend::TAD=FiniteDiffBackend(), +function TangentDiffBackend(diff_backend::TAD; retraction::TR = ExponentialRetraction(), inverse_retraction::TIR = LogarithmicInverseRetraction(), basis::TB = DefaultOrthonormalBasis() @@ -198,7 +197,7 @@ globally default differentiation backend for calculating gradients. [`Manifolds.gradient(::AbstractManifold, ::Any, ::Any, ::AbstractRiemannianDiffBackend)`](@ref) """ -const _current_rgradient_backend = CurrentRiemannianDiffBackend(TangentDiffBackend()) +const _current_rgradient_backend = CurrentRiemannianDiffBackend(TangentDiffBackend(FiniteDifferencesBackend())) """ _current_rdifferential_backend @@ -210,7 +209,7 @@ globally default differentiation backend for calculating differentials. [`Manifolds.differential`](@ref) """ -const _current_rdifferential_backend = CurrentRiemannianDiffBackend(TangentDiffBackend()) +const _current_rdifferential_backend = CurrentRiemannianDiffBackend(TangentDiffBackend(FiniteDifferencesBackend())) """ rgradient_backend() -> AbstractRiemannianDiffBackend @@ -246,7 +245,7 @@ function rdifferential_backend!(backend::AbstractRiemannianDiffBackend) return backend end -""" +@doc raw""" RiemannianProjectionBackend <: AbstractRiemannianDiffBackend This backend computes the differentiation in the embedding, which is currently limited