From 9bf151ddc62254503446b8f2a1313abeda82f019 Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Tue, 7 Jun 2022 17:03:13 +0200 Subject: [PATCH] Fixing issues with SE(n) (#491) * Fixing issues with SE(n) * GL needs to overload non-mutating exp and log * fix manifold operations on direct product group * basis tests for SE(n) * bump version --- Project.toml | 2 +- src/groups/general_linear.jl | 10 ++++-- src/groups/metric.jl | 54 ++++++++++---------------------- src/groups/product_group.jl | 13 ++++++++ src/groups/special_euclidean.jl | 31 +++++++++++++++++- test/groups/product_group.jl | 1 + test/groups/special_euclidean.jl | 20 ++++++++++++ 7 files changed, 89 insertions(+), 42 deletions(-) diff --git a/Project.toml b/Project.toml index 6673306d8e..3f5e46b3d4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Manifolds" uuid = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e" authors = ["Seth Axen ", "Mateusz Baran ", "Ronny Bergmann ", "Antoine Levitt "] -version = "0.8.8" +version = "0.8.9" [deps] Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" diff --git a/src/groups/general_linear.jl b/src/groups/general_linear.jl index b685c8081d..99ce1215e5 100644 --- a/src/groups/general_linear.jl +++ b/src/groups/general_linear.jl @@ -80,7 +80,10 @@ the conjugate transpose. [^AndruchowLarotondaRechtVarela2014][^MartinNeff2016] > doi: [10.3934/jgm.2016010](https://doi.org/10.3934/jgm.2016010), > arXiv: [1409.7849v2](https://arxiv.org/abs/1409.7849v2). """ -exp(::GeneralLinear, p, X) +function exp(M::GeneralLinear, p, X) + q = similar(p) + return exp!(M, q, p, X) +end function exp!(G::GeneralLinear, q, p, X) expX = exp(X) @@ -187,7 +190,10 @@ supergroup `GeneralLinear(2n)` and the resulting tangent vector is then complexi Note that this implementation is experimental. """ -log(::GeneralLinear, p, q) +function log(M::GeneralLinear, p, q) + X = similar(p) + return log!(M, X, p, q) +end function log!(G::GeneralLinear{n,𝔽}, X, p, q) where {n,𝔽} pinvq = inverse_translate(G, p, q, LeftAction()) diff --git a/src/groups/metric.jl b/src/groups/metric.jl index e0ab8f6783..17dd7eb538 100644 --- a/src/groups/metric.jl +++ b/src/groups/metric.jl @@ -65,17 +65,17 @@ direction(::TraitList{HasLeftInvariantMetric}, ::AbstractDecoratorManifold) = Le direction(::TraitList{HasRightInvariantMetric}, ::AbstractDecoratorManifold) = RightAction() -function exp(::TraitList{HasLeftInvariantMetric}, M::MetricManifold, p, X) - return retract(M.manifold, p, X, GroupExponentialRetraction(LeftAction())) +function exp(::TraitList{HasLeftInvariantMetric}, M::AbstractDecoratorManifold, p, X) + return retract(M, p, X, GroupExponentialRetraction(LeftAction())) end -function exp!(::TraitList{HasLeftInvariantMetric}, M::MetricManifold, q, p, X) - return retract!(M.manifold, q, p, X, GroupExponentialRetraction(LeftAction())) +function exp!(::TraitList{HasLeftInvariantMetric}, M::AbstractDecoratorManifold, q, p, X) + return retract!(M, q, p, X, GroupExponentialRetraction(LeftAction())) end -function exp(::TraitList{HasRightInvariantMetric}, M::MetricManifold, p, X) - return retract(M.manifold, p, X, GroupExponentialRetraction(RightAction())) +function exp(::TraitList{HasRightInvariantMetric}, M::AbstractDecoratorManifold, p, X) + return retract(M, p, X, GroupExponentialRetraction(RightAction())) end -function exp!(::TraitList{HasRightInvariantMetric}, M::MetricManifold, q, p, X) - return retract!(M.manifold, q, p, X, GroupExponentialRetraction(RightAction())) +function exp!(::TraitList{HasRightInvariantMetric}, M::AbstractDecoratorManifold, q, p, X) + return retract!(M, q, p, X, GroupExponentialRetraction(RightAction())) end function exp(::TraitList{HasBiinvariantMetric}, M::MetricManifold, p, X) return exp(M.manifold, p, X) @@ -158,39 +158,17 @@ function inverse_translate_diff!( return inverse_translate_diff!(M.manifold, Y, p, q, X, conv) end -function log(::TraitList{HasLeftInvariantMetric}, M::MetricManifold, p, q) - return inverse_retract( - M.manifold, - p, - q, - GroupLogarithmicInverseRetraction(LeftAction()), - ) +function log(::TraitList{HasLeftInvariantMetric}, M::AbstractDecoratorManifold, p, q) + return inverse_retract(M, p, q, GroupLogarithmicInverseRetraction(LeftAction())) end -function log!(::TraitList{HasLeftInvariantMetric}, M::MetricManifold, X, p, q) - return inverse_retract!( - M.manifold, - X, - p, - q, - GroupLogarithmicInverseRetraction(LeftAction()), - ) +function log!(::TraitList{HasLeftInvariantMetric}, M::AbstractDecoratorManifold, X, p, q) + return inverse_retract!(M, X, p, q, GroupLogarithmicInverseRetraction(LeftAction())) end -function log(::TraitList{HasRightInvariantMetric}, M::MetricManifold, p, q) - return inverse_retract( - M.manifold, - p, - q, - GroupLogarithmicInverseRetraction(RightAction()), - ) +function log(::TraitList{HasRightInvariantMetric}, M::AbstractDecoratorManifold, p, q) + return inverse_retract(M, p, q, GroupLogarithmicInverseRetraction(RightAction())) end -function log!(::TraitList{HasRightInvariantMetric}, M::MetricManifold, X, p, q) - return inverse_retract!( - M.manifold, - X, - p, - q, - GroupLogarithmicInverseRetraction(RightAction()), - ) +function log!(::TraitList{HasRightInvariantMetric}, M::AbstractDecoratorManifold, X, p, q) + return inverse_retract!(M, X, p, q, GroupLogarithmicInverseRetraction(RightAction())) end function log(::TraitList{HasBiinvariantMetric}, M::MetricManifold, p, q) return log(M.manifold, p, q) diff --git a/src/groups/product_group.jl b/src/groups/product_group.jl index 4c32482381..fa0f91a377 100644 --- a/src/groups/product_group.jl +++ b/src/groups/product_group.jl @@ -27,6 +27,19 @@ function ProductGroup(manifold::ProductManifold{𝔽}) where {𝔽} return GroupManifold(manifold, op) end +@inline function active_traits(f, M::ProductGroup, args...) + if is_metric_function(f) + #pass to manifold by default - but keep Group Decorator for the retraction + return merge_traits(IsGroupManifold(M.op), IsExplicitDecorator()) + else + return merge_traits( + IsGroupManifold(M.op), + active_traits(f, M.manifold, args...), + IsExplicitDecorator(), + ) + end +end + function identity_element(G::ProductGroup) M = G.manifold return ProductRepr(map(identity_element, M.manifolds)) diff --git a/src/groups/special_euclidean.jl b/src/groups/special_euclidean.jl index ced2a8362c..5e6fa772c6 100644 --- a/src/groups/special_euclidean.jl +++ b/src/groups/special_euclidean.jl @@ -49,6 +49,34 @@ const SpecialEuclideanIdentity{N} = Identity{SpecialEuclideanOperation{N}} Base.show(io::IO, ::SpecialEuclidean{n}) where {n} = print(io, "SpecialEuclidean($(n))") +_is_se_forwarded_function(::Any) = false +for mf in [ + flat!, + get_basis, + get_coordinates, + get_coordinates!, + get_vector, + get_vector!, + get_vectors, + inner, + norm, + sharp!, +] + @eval _is_se_forwarded_function(::typeof($mf)) = true +end + +@inline function active_traits(f, M::SpecialEuclidean, args...) + if is_metric_function(f) && !_is_se_forwarded_function(f) + return merge_traits(IsGroupManifold(M.op), HasLeftInvariantMetric()) + else + return merge_traits( + IsGroupManifold(M.op), + active_traits(f, M.manifold, args...), + IsExplicitDecorator(), + ) + end +end + Base.@propagate_inbounds function submanifold_component( ::Union{SpecialEuclidean{n},SpecialEuclideanManifold{n}}, p::AbstractMatrix, @@ -260,7 +288,8 @@ function compose!( p::AbstractMatrix, q::AbstractMatrix, ) - return mul!(x, p, q) + copyto!(x, p * q) + return x end @doc raw""" diff --git a/test/groups/product_group.jl b/test/groups/product_group.jl index 5da40d2e46..58db19e495 100644 --- a/test/groups/product_group.jl +++ b/test/groups/product_group.jl @@ -70,6 +70,7 @@ include("group_utils.jl") @test compose(G, pts[1], Identity(G)) == pts[1] @test compose(G, Identity(G), pts[1]) == pts[1] test_group(G, pts, X_pts, X_pts; test_diff=true, test_mutating=false) + test_manifold(G, pts; is_mutating=false) @test isapprox( G, exp_lie(G, X_pts[1]), diff --git a/test/groups/special_euclidean.jl b/test/groups/special_euclidean.jl index 3ff00aa737..3e16191226 100644 --- a/test/groups/special_euclidean.jl +++ b/test/groups/special_euclidean.jl @@ -44,6 +44,8 @@ Random.seed!(10) ] end + basis_types = (DefaultOrthonormalBasis(),) + @testset "product repr" begin pts = [ProductRepr(tp...) for tp in tuple_pts] X_pts = [ProductRepr(tX...) for tX in tuple_X] @@ -77,6 +79,16 @@ Random.seed!(10) test_adjoint_action=true, diff_convs=[(), (LeftAction(),), (RightAction(),)], ) + test_manifold( + G, + pts; + basis_types_vecs=basis_types, + basis_types_to_from=basis_types, + is_mutating=true, + #test_inplace=true, + test_vee_hat=true, + exp_log_atol_multiplier=50, + ) end @testset "affine matrix" begin @@ -92,6 +104,14 @@ Random.seed!(10) diff_convs=[(), (LeftAction(),), (RightAction(),)], atol=1e-9, ) + test_manifold( + G, + pts; + is_mutating=true, + #test_inplace=true, + test_vee_hat=true, + exp_log_atol_multiplier=50, + ) # specific affine tests p = copy(G, pts[1]) X = copy(G, p, X_pts[1])