Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Polish the Differential Interfaces #429

Closed
wants to merge 21 commits into from
Closed
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
4376b8b
Squashed commit of the following:
kellertuer Sep 24, 2021
b2f2674
Merge branch 'master' into kellertuer/polish-diff-interfaces
kellertuer Sep 24, 2021
59c0467
first sketch of the ODEExponenitalRetraction
kellertuer Sep 25, 2021
f6bede4
Fixes two typos.
kellertuer Sep 25, 2021
475140e
Fix more typos.
kellertuer Sep 25, 2021
61b4dcb
Move solve_exp_ode with OridnaryDiffEq to differentiation and change …
kellertuer Sep 25, 2021
49ecd3c
fixes a first test.
kellertuer Sep 25, 2021
1885973
Ran Juliaformatter 23.564 times today, just not before the last commit.
kellertuer Sep 25, 2021
2181cff
trying to improve tests?
kellertuer Sep 25, 2021
cde3c8e
Fixes most errors, just some default retractions need to be fixed (st…
kellertuer Sep 26, 2021
e1bf764
reduce errors further.
kellertuer Sep 26, 2021
b112499
reduices errors as far as I am able to do.
kellertuer Sep 27, 2021
f2625fb
Implements my understanding of what the methods should be – error pe…
kellertuer Sep 27, 2021
a6702f6
fix some tests
mateuszbaran Sep 27, 2021
a4a2716
starts seperate ODE tests – but facing a new problem.
kellertuer Sep 28, 2021
5ed5feb
getting again stuck on metrics.
kellertuer Sep 28, 2021
d544ce9
fixing ode tests a bit
mateuszbaran Sep 28, 2021
26ab2ec
Fighting dimensions within the ODE solver.
kellertuer Sep 28, 2021
bd39060
Dimensions fixed, now just the result is wrong :D
kellertuer Sep 28, 2021
44ceb3e
reorganise to get the ODE stuff (a) defined on arbitrary manifolds an…
kellertuer Sep 29, 2021
6ac5d75
Revert to an embedded approach. Maybe.
kellertuer Oct 1, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ version = "0.6.9"
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Einsum = "b7d42ee7-0b51-5a75-98ca-779d3107e4c0"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
HybridArrays = "1baab800-613f-4b0a-84e4-9cd3431bfbb9"
Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e"
LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
Expand All @@ -29,7 +28,6 @@ Colors = "0.12"
Distributions = "0.22.6, 0.23, 0.24, 0.25"
Einsum = "0.4"
FiniteDiff = "2"
FiniteDifferences = "0.12"
HybridArrays = "0.4"
Kronecker = "0.4, 0.5"
LightGraphs = "1"
Expand All @@ -48,6 +46,7 @@ julia = "1.5"
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
DoubleFloats = "497a8b3b-efae-58df-a0af-a86822472b78"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Gtk = "4c0ca9eb-093a-5379-98c5-f87ac0bbbf44"
ImageIO = "82e4d734-157c-48bb-816b-45c225c6df19"
Expand All @@ -65,4 +64,4 @@ VisualRegressionTests = "34922c18-7c2a-561c-bac1-01e79b2c4c92"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "Colors", "DoubleFloats", "FiniteDiff", "ForwardDiff", "Gtk", "ImageIO", "ImageMagick", "OrdinaryDiffEq", "NLsolve", "Plots", "PyPlot", "Quaternions", "QuartzImageIO", "RecipesBase", "ReverseDiff", "Zygote"]
test = ["Test", "Colors", "DoubleFloats", "FiniteDiff", "FiniteDifferences", "ForwardDiff", "Gtk", "ImageIO", "ImageMagick", "OrdinaryDiffEq", "NLsolve", "Plots", "PyPlot", "Quaternions", "QuartzImageIO", "RecipesBase", "ReverseDiff", "Zygote"]
15 changes: 10 additions & 5 deletions src/Manifolds.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ import Base:
using Base.Iterators: repeated
using Distributions
using Einsum: @einsum
using FiniteDifferences
using HybridArrays
using Kronecker
using LightGraphs
Expand Down Expand Up @@ -287,14 +286,19 @@ function __init__()
include("differentiation/finite_diff.jl")
end

@require FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" begin
using .FiniteDifferences
include("differentiation/finite_differences.jl")
end

@require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" begin
using .ForwardDiff
include("differentiation/forward_diff.jl")
end

@require OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" begin
using .OrdinaryDiffEq: ODEProblem, AutoVern9, Rodas5, solve
include("ode.jl")
include("differentiation/ode.jl")
end

@require NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" begin
Expand Down Expand Up @@ -440,6 +444,7 @@ export AbstractRetractionMethod,
PolarRetraction,
ProjectionRetraction,
SoftmaxRetraction,
ODEExponentialRetraction,
PadeRetraction,
ProductRetraction,
PowerRetraction
Expand Down Expand Up @@ -664,9 +669,9 @@ export get_basis,
export AbstractDiffBackend,
AbstractRiemannianDiffBackend,
FiniteDifferencesBackend,
RiemannianONBDiffBackend,
RiemannianProjectionGradientBackend
export diff_backend, diff_backend!, diff_backends
TangentDiffBackend,
RiemannianProjectionBackend
kellertuer marked this conversation as resolved.
Show resolved Hide resolved
export default_differential_backend, set_default_differential_backend!
# atlases and charts
export get_point, get_point!, get_parameters, get_parameters!

Expand Down
109 changes: 41 additions & 68 deletions src/differentiation/differentiation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ struct NoneDiffBackend <: AbstractDiffBackend end

Compute the derivative of a callable `f` at time `t` computed using the given `backend`,
an object of type [`Manifolds.AbstractDiffBackend`](@ref). If the backend is not explicitly
specified, it is obtained using the function [`Manifolds.diff_backend`](@ref).
specified, it is obtained using the function [`default_differential_backend`](@ref).

This function calculates plain Euclidean derivatives, for Riemannian differentiation see
for example [`differential`](@ref Manifolds.differential(::AbstractManifold, ::Any, ::Real, ::AbstractRiemannianDiffBackend)).
Expand All @@ -26,7 +26,9 @@ for example [`differential`](@ref Manifolds.differential(::AbstractManifold, ::A
"""
function _derivative end

function _derivative!(f, X, t, backend::AbstractDiffBackend)
_derivative(f, t) = _derivative(f, t, default_differential_backend())

function _derivative!(f, X, t, backend::AbstractDiffBackend=default_differential_backend())
return copyto!(X, _derivative(f, t, backend))
end

Expand All @@ -35,7 +37,7 @@ end

Compute the gradient of a callable `f` at point `p` computed using the given `backend`,
an object of type [`AbstractDiffBackend`](@ref). If the backend is not explicitly
specified, it is obtained using the function [`diff_backend`](@ref).
specified, it is obtained using the function [`default_differential_backend`](@ref).

This function calculates plain Euclidean gradients, for Riemannian gradient calculation see
for example [`gradient`](@ref Manifolds.gradient(::AbstractManifold, ::Any, ::Any, ::AbstractRiemannianDiffBackend)).
Expand All @@ -47,98 +49,69 @@ for example [`gradient`](@ref Manifolds.gradient(::AbstractManifold, ::Any, ::An
"""
function _gradient end

function _gradient!(f, X, p, backend::AbstractDiffBackend)
_gradient(f, p) = _gradient(f, p, default_differential_backend())

function _gradient!(f, X, p, backend::AbstractDiffBackend=default_differential_backend())
return copyto!(X, _gradient(f, p, backend))
end

"""
_jacobian(f, p[, backend::AbstractDiffBackend])

Compute the jacobian of a callable `f` at point `p` computed using the given `backend`,
an object of type [`AbstractDiffBackend`](@ref). If the backend is not explicitly
specified, it is obtained using the function [`default_differential_backend`](@ref).

This function calculates plain Euclidean gradients, for Riemannian gradient calculation see
for example [`gradient`](@ref Manifolds.gradient(::AbstractManifold, ::Any, ::Any, ::AbstractRiemannianDiffBackend)).

!!! note

Not specifying the backend explicitly will usually result in a type instability
and decreased performance.
"""
function _jacobian end

_jacobian(f, p) = _jacobian(f, p, default_differential_backend())

function _jacobian!(f, X, p, backend::AbstractDiffBackend=default_differential_backend())
return copyto!(X, _jacobian(f, p, backend))
end

"""
CurrentDiffBackend(backend::AbstractDiffBackend)

A mutable struct for storing the current differentiation backend in a global
constant [`_current_diff_backend`](@ref).
constant [`_current_default_differential_backend`](@ref).

# See also

[`AbstractDiffBackend`](@ref), [`diff_backend`](@ref), [`diff_backend!`](@ref)
[`AbstractDiffBackend`](@ref), [`default_differential_backend`](@ref), [`set_default_differential_backend`](@ref)
"""
mutable struct CurrentDiffBackend
backend::AbstractDiffBackend
end

"""
_current_diff_backend
_current_default_differential_backend

The instance of [`Manifolds.CurrentDiffBackend`](@ref) that stores the globally default
differentiation backend.
"""
const _current_diff_backend = CurrentDiffBackend(NoneDiffBackend())

"""
_diff_backends

A vector of valid [`Manifolds.AbstractDiffBackend`](@ref).
const _current_default_differential_backend = CurrentDiffBackend(NoneDiffBackend())
"""
const _diff_backends = AbstractDiffBackend[]
default_differential_backend() -> AbstractDiffBackend

Get the default differentiation backend.
"""
diff_backend() -> AbstractDiffBackend
default_differential_backend() = _current_default_differential_backend.backend

Get the current differentiation backend.
"""
diff_backend() = _current_diff_backend.backend

"""
diff_backend!(backend::AbstractDiffBackend)
set_default_differential_backend!(backend::AbstractDiffBackend)

Set current backend for differentiation to `backend`.
"""
function diff_backend!(backend::AbstractDiffBackend)
_current_diff_backend.backend = backend
function set_default_differential_backend!(backend::AbstractDiffBackend)
_current_default_differential_backend.backend = backend
return backend
end

"""
diff_backends() -> Vector{AbstractDiffBackend}

Get vector of currently valid differentiation backends.
"""
diff_backends() = _diff_backends

_derivative(f, t) = _derivative(f, t, diff_backend())

_derivative!(f, X, t) = _derivative!(f, X, t, diff_backend())

_gradient(f, p) = _gradient(f, p, diff_backend())

_gradient!(f, X, p) = _gradient!(f, X, p, diff_backend())

# Finite differences

"""
FiniteDifferencesBackend(method::FiniteDifferenceMethod = central_fdm(5, 1))

Differentiation backend based on the FiniteDifferences package.
"""
struct FiniteDifferencesBackend{TM<:FiniteDifferenceMethod} <: AbstractDiffBackend
method::TM
end

function FiniteDifferencesBackend()
return FiniteDifferencesBackend(central_fdm(5, 1))
end

push!(_diff_backends, FiniteDifferencesBackend())

diff_backend!(_diff_backends[end])

function _derivative(f, t, backend::FiniteDifferencesBackend)
return backend.method(f, t)
end

function _gradient(f, p, backend::FiniteDifferencesBackend)
return FiniteDifferences.grad(backend.method, f, p)[1]
end

function _jacobian(f, p, backend::FiniteDifferencesBackend)
return FiniteDifferences.jacobian(backend.method, f, p)[1]
end
27 changes: 20 additions & 7 deletions src/differentiation/finite_diff.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,38 @@

"""
FiniteDiffBackend(method::Val{Symbol} = Val{:central})
FiniteDiffBackend <: AbstractDiffBackend

A type to specify / use differentiation backend based on FiniteDiff package.

Differentiation backend based on FiniteDiff package.
# Constructor
FiniteDiffBackend(method::Val{Symbol} = Val{:central})
"""
struct FiniteDiffBackend{TM<:Val} <: AbstractDiffBackend
method::TM
end

FiniteDiffBackend() = FiniteDiffBackend(Val(:central))

push!(_diff_backends, FiniteDiffBackend())

function _derivative(f, p, backend::FiniteDiffBackend{Method}) where {Method}
function _derivative(f, p, ::FiniteDiffBackend{Method}) where {Method}
return FiniteDiff.finite_difference_derivative(f, p, Method)
end

function _gradient(f, p, backend::FiniteDiffBackend{Method}) where {Method}
function _gradient(f, p, ::FiniteDiffBackend{Method}) where {Method}
return FiniteDiff.finite_difference_gradient(f, p, Method)
end

function _gradient!(f, X, p, backend::FiniteDiffBackend{Method}) where {Method}
function _gradient!(f, X, p, ::FiniteDiffBackend{Method}) where {Method}
return FiniteDiff.finite_difference_gradient!(X, f, p, Method)
end

function _jacobian(f, p, ::FiniteDiffBackend{Method}) where {Method}
return FiniteDiff.finite_difference_jacobian(f, p, Method)
end

function _jacobian!(f, X, p, ::FiniteDiffBackend{Method}) where {Method}
return FiniteDiff.finite_difference_jacobian!(X, f, p, Method)
end

if default_differential_backend() === NoneDiffBackend()
set_default_differential_backend!(FiniteDiffBackend())
end
28 changes: 28 additions & 0 deletions src/differentiation/finite_differences.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""
FiniteDifferencesBackend(method::FiniteDifferenceMethod = central_fdm(5, 1))

Differentiation backend based on the FiniteDifferences package.
"""
struct FiniteDifferencesBackend{TM<:FiniteDifferenceMethod} <: AbstractDiffBackend
method::TM
end

function FiniteDifferencesBackend()
return FiniteDifferencesBackend(central_fdm(5, 1))
end

function _derivative(f, t, backend::FiniteDifferencesBackend)
return backend.method(f, t)
end

function _gradient(f, p, backend::FiniteDifferencesBackend)
return FiniteDifferences.grad(backend.method, f, p)[1]
end

function _jacobian(f, p, backend::FiniteDifferencesBackend)
return FiniteDifferences.jacobian(backend.method, f, p)[1]
end

if default_differential_backend() === NoneDiffBackend()
set_default_differential_backend!(FiniteDifferencesBackend())
end
4 changes: 3 additions & 1 deletion src/differentiation/forward_diff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,6 @@ function _jacobian(f, p, ::ForwardDiffBackend)
return ForwardDiff.jacobian(f, p)
end

push!(_diff_backends, ForwardDiffBackend())
if default_differential_backend() === NoneDiffBackend()
set_default_differential_backend!(ForwardDiffBackend())
end
35 changes: 35 additions & 0 deletions src/differentiation/ode.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
function solve_exp_ode(
M::AbstractConnectionManifold,
p,
X;
basis::AbstractBasis=DefaultOrthonormalBasis(),
solver=AutoVern9(Rodas5()),
backend=default_differential_backend(),
retraction::AbstractRetractionMethod=ManifoldsBase.default_retraction_method(M),
kwargs...,
)
d = manifold_dimension(M)
iv = SVector{d}(1:d)
ix = SVector{d}((d + 1):(2 * d))
u0 = allocate(p, 2 * d)
u0[iv] .= X
u0[ix] .= p

function exp_problem(u, params, t)
M = params[1]
dx = u[iv]
q = u[ix]
ddx = allocate(u, Size(d))
du = allocate(u)
Γ = christoffel_symbols_second(M, q, basis; backend=backend, retraction=retraction)
@einsum ddx[k] = -Γ[k, i, j] * dx[i] * dx[j]
du[iv] .= ddx
du[ix] .= dx
return Base.convert(typeof(u), du)
end
params = (M,)
prob = ODEProblem(exp_problem, u0, (0.0, 1.0), params)
sol = solve(prob, solver; kwargs...)
q = sol.u[1][(d + 1):(2 * d)]
return q
end
4 changes: 3 additions & 1 deletion src/differentiation/reverse_diff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,6 @@ function Manifolds._gradient!(f, X, p, ::ReverseDiffBackend)
return ReverseDiff.gradient!(X, f, p)
end

push!(Manifolds._diff_backends, ReverseDiffBackend())
if default_differential_backend() === NoneDiffBackend()
set_default_differential_backend!(ReverseDiffBackend())
end
Loading