From 13012406a348ed0a6324a9474e0fbfef3d333cd2 Mon Sep 17 00:00:00 2001 From: Olivier Verdier Date: Mon, 13 Nov 2023 16:16:05 +0100 Subject: [PATCH] Implement `translate_diff` and `inv_diff` for all groups (#679) --- src/groups/addition_operation.jl | 19 +++--------- src/groups/circle_group.jl | 37 +++++++++++++--------- src/groups/general_linear.jl | 9 +++--- src/groups/general_unitary_groups.jl | 31 +++---------------- src/groups/group.jl | 46 +++++++++++++++------------- src/groups/heisenberg.jl | 8 ++--- src/groups/power_group.jl | 40 +++++++++++++----------- src/groups/product_group.jl | 26 ++++++---------- src/groups/special_linear.jl | 11 +++---- 9 files changed, 97 insertions(+), 130 deletions(-) diff --git a/src/groups/addition_operation.jl b/src/groups/addition_operation.jl index d4a8920e7c..539a3deca7 100644 --- a/src/groups/addition_operation.jl +++ b/src/groups/addition_operation.jl @@ -160,25 +160,14 @@ lie_bracket(::AdditionGroupTrait, G::AbstractDecoratorManifold, X, Y) = zero(X) lie_bracket!(::AdditionGroupTrait, G::AbstractDecoratorManifold, Z, X, Y) = fill!(Z, 0) -function translate_diff( - ::AdditionGroupTrait, - G::AbstractDecoratorManifold, - p, - q, - X, - ::ActionDirectionAndSide, -) - return X -end - -function translate_diff!( +function adjoint_action!( ::AdditionGroupTrait, G::AbstractDecoratorManifold, Y, p, - q, X, - ::ActionDirectionAndSide, + ::ActionDirection, ) - return copyto!(G, Y, p, X) + return copyto!(Y, X) end + diff --git a/src/groups/circle_group.jl b/src/groups/circle_group.jl index 34c97fb430..fb965ebe5b 100644 --- a/src/groups/circle_group.jl +++ b/src/groups/circle_group.jl @@ -24,9 +24,9 @@ end Base.show(io::IO, ::CircleGroup) = print(io, "CircleGroup()") -adjoint_action(::CircleGroup, p, X) = X +adjoint_action(::CircleGroup, p, X, ::ActionDirection) = X -adjoint_action!(::CircleGroup, Y, p, X) = copyto!(Y, X) +adjoint_action!(::CircleGroup, Y, p, X, ::ActionDirection) = copyto!(Y, X) function compose( ::MultiplicationGroupTrait, @@ -77,22 +77,29 @@ lie_bracket(::CircleGroup, X, Y) = zero(X) lie_bracket!(::CircleGroup, Z, X, Y) = fill!(Z, 0) -function translate_diff(::GT, p, q, X, ::ActionDirectionAndSide) where {GT<:CircleGroup} - return map(*, p, X) +_common_translate_diff(::Any, p, q, X) = map(*, p, X) +function translate_diff(G::GT, p, q, X, ::LeftForwardAction) where {GT<:CircleGroup} + return _common_translate_diff(G, p, q, X) end -function translate_diff( - ::CircleGroup, - ::Identity{MultiplicationOperation}, - q, - X, - ::ActionDirectionAndSide, -) - return X +function translate_diff(G::GT, p, q, X, ::RightForwardAction) where {GT<:CircleGroup} + return _common_translate_diff(G, p, q, X) end - -function translate_diff!(G::CircleGroup, Y, p, q, X, conv::ActionDirectionAndSide) - return copyto!(Y, translate_diff(G, p, q, X, conv)) +function translate_diff(G::GT, p, q, X, ::LeftBackwardAction) where {GT<:CircleGroup} + return _common_translate_diff(G, p, q, X) end +function translate_diff(G::GT, p, q, X, ::RightBackwardAction) where {GT<:CircleGroup} + return _common_translate_diff(G, p, q, X) +end +translate_diff(::CircleGroup, ::Identity{MultiplicationOperation}, q, X, ::LeftForwardAction) = X +translate_diff(::CircleGroup, ::Identity{MultiplicationOperation}, q, X, ::RightForwardAction) = X +translate_diff(::CircleGroup, ::Identity{MultiplicationOperation}, q, X, ::LeftBackwardAction) = X +translate_diff(::CircleGroup, ::Identity{MultiplicationOperation}, q, X, ::RightBackwardAction) = X + +_common_translate_diff!(G, Y, p, q, X, conv) = copyto!(Y, translate_diff(G, p, q, X, conv)) +translate_diff!(G::CircleGroup, Y, p, q, X, conv::LeftForwardAction) = _common_translate_diff!(G, Y, p, q, X, conv) +translate_diff!(G::CircleGroup, Y, p, q, X, conv::RightForwardAction) = _common_translate_diff!(G, Y, p, q, X, conv) +translate_diff!(G::CircleGroup, Y, p, q, X, conv::LeftBackwardAction) = _common_translate_diff!(G, Y, p, q, X, conv) +translate_diff!(G::CircleGroup, Y, p, q, X, conv::RightBackwardAction) = _common_translate_diff!(G, Y, p, q, X, conv) function exp_lie(::CircleGroup, X) return map(X) do imθ diff --git a/src/groups/general_linear.jl b/src/groups/general_linear.jl index 8f2f24d21b..d4b58eb66f 100644 --- a/src/groups/general_linear.jl +++ b/src/groups/general_linear.jl @@ -275,9 +275,8 @@ function Base.show(io::IO, M::GeneralLinear{Tuple{Int},𝔽}) where {𝔽} return print(io, "GeneralLinear($n, $𝔽; parameter=:field)") end -translate_diff(::GeneralLinear, p, q, X, ::LeftForwardAction) = X -translate_diff(::GeneralLinear, p, q, X, ::RightBackwardAction) = p \ X * p +# note: this implementation is not optimal +adjoint_action!(::GeneralLinear, Y, p, X, ::LeftAction) = copyto!(Y, p * X * inv(p)) +adjoint_action!(::GeneralLinear, Y, p, X, ::RightAction) = copyto!(Y, p \ X * p) + -function translate_diff!(G::GeneralLinear, Y, p, q, X, conv::ActionDirectionAndSide) - return copyto!(Y, translate_diff(G, p, q, X, conv)) -end diff --git a/src/groups/general_unitary_groups.jl b/src/groups/general_unitary_groups.jl index a280caba26..a9f63b2a64 100644 --- a/src/groups/general_unitary_groups.jl +++ b/src/groups/general_unitary_groups.jl @@ -296,45 +296,22 @@ function Random.rand!(rng::AbstractRNG, G::GeneralUnitaryMultiplicationGroup, pX return pX end -function translate_diff!( +function adjoint_action!( G::GeneralUnitaryMultiplicationGroup, Y, p, - q, - X, - ::LeftForwardAction, -) - return copyto!(G, Y, X) -end -function translate_diff!( - G::GeneralUnitaryMultiplicationGroup, - Y, - p, - q, X, - ::RightForwardAction, -) - copyto!(G, Y, X) - return Y -end -function translate_diff!( - G::GeneralUnitaryMultiplicationGroup, - Y, - p, - q, - X, - ::LeftBackwardAction, + ::LeftAction ) copyto!(G, Y, p * X * inv(G, p)) return Y end -function translate_diff!( +function adjoint_action!( G::GeneralUnitaryMultiplicationGroup, Y, p, - q, X, - ::RightBackwardAction, + ::RightAction, ) return copyto!(G, Y, inv(G, p) * X * p) end diff --git a/src/groups/group.jl b/src/groups/group.jl index 9d138b975e..fb53c52420 100644 --- a/src/groups/group.jl +++ b/src/groups/group.jl @@ -426,7 +426,7 @@ function is_vector( end @doc raw""" - adjoint_action(G::AbstractDecoratorManifold, p, X) + adjoint_action(G::AbstractDecoratorManifold, p, X, dir) Adjoint action of the element `p` of the Lie group `G` on the element `X` of the corresponding Lie algebra. @@ -443,26 +443,18 @@ where ``e`` is the identity element of `G`. Note that the adjoint representation of a Lie group isn't generally faithful. Notably the adjoint representation of SO(2) is trivial. """ -adjoint_action(G::AbstractDecoratorManifold, p, X) -@trait_function adjoint_action(G::AbstractDecoratorManifold, p, Xₑ) -function adjoint_action(::TraitList{<:IsGroupManifold}, G::AbstractDecoratorManifold, p, Xₑ) - Xₚ = translate_diff(G, p, Identity(G), Xₑ, LeftForwardAction()) - Y = inverse_translate_diff(G, p, p, Xₚ, RightBackwardAction()) - return Y -end - -@trait_function adjoint_action!(G::AbstractDecoratorManifold, Y, p, Xₑ) -function adjoint_action!( - ::TraitList{<:IsGroupManifold}, - G::AbstractDecoratorManifold, - Y, - p, - Xₑ, -) - Xₚ = translate_diff(G, p, Identity(G), Xₑ, LeftForwardAction()) - inverse_translate_diff!(G, Y, p, p, Xₚ, RightBackwardAction()) - return Y -end +adjoint_action(G::AbstractDecoratorManifold, p, X, dir) +@trait_function adjoint_action(G::AbstractDecoratorManifold, p, Xₑ, dir) +@trait_function adjoint_action!(G::AbstractDecoratorManifold, Y, p, Xₑ, dir) +function adjoint_action(::TraitList{<:IsGroupManifold}, G::AbstractDecoratorManifold, p, Xₑ, dir) + Y = allocate_result(G, adjoint_action, Xₑ, p) + return adjoint_action!(G, Y, p, Xₑ, dir) +end +# backward compatibility +adjoint_action(G::AbstractDecoratorManifold, p, X) = adjoint_action(G, p, X, LeftAction()) +adjoint_action!(G::AbstractDecoratorManifold, Y, p, X) = adjoint_action!(G, Y, p, X, LeftAction()) +# fall back method: the right action is defined from the left action +adjoint_action!(G::AbstractDecoratorManifold, Y, p, X, ::RightAction) = adjoint_action!(G, Y, inv(G, p), X, LeftAction()) @doc raw""" adjoint_inv_diff(G::AbstractDecoratorManifold, p, X) @@ -913,6 +905,18 @@ end X, conv::ActionDirectionAndSide=LeftForwardAction(), ) +translate_diff(::AbstractDecoratorManifold, ::Any, ::Any, X, ::LeftForwardAction) = X +translate_diff(::AbstractDecoratorManifold, ::Any, ::Any, X, ::RightForwardAction) = X +translate_diff!(G::AbstractDecoratorManifold, Y, ::Any, ::Any, X, ::LeftForwardAction) = copyto!(G, Y, X) +translate_diff!(G::AbstractDecoratorManifold, Y, ::Any, ::Any, X, ::RightForwardAction) = copyto!(G, Y, X) +translate_diff!(G::AbstractDecoratorManifold, Y, p, ::Any, X, ::LeftBackwardAction) = adjoint_action!(G, Y, p, X, LeftAction()) +translate_diff!(G::AbstractDecoratorManifold, Y, p, ::Any, X, ::RightBackwardAction) = adjoint_action!(G, Y, p, X, RightAction()) + +translate_diff(::AbstractDecoratorManifold, ::Identity, q, X, ::LeftForwardAction) = X +translate_diff(::AbstractDecoratorManifold, ::Identity, q, X, ::RightForwardAction) = X +translate_diff(::AbstractDecoratorManifold, ::Identity, q, X, ::LeftBackwardAction) = X +translate_diff(::AbstractDecoratorManifold, ::Identity, q, X, ::RightBackwardAction) = X + @doc raw""" inverse_translate_diff(G::AbstractDecoratorManifold, p, q, X, conv::ActionDirectionAndSide=LeftForwardAction()) diff --git a/src/groups/heisenberg.jl b/src/groups/heisenberg.jl index 2c5793cae3..3554c53497 100644 --- a/src/groups/heisenberg.jl +++ b/src/groups/heisenberg.jl @@ -413,9 +413,7 @@ function Base.show(io::IO, M::HeisenbergGroup{Tuple{Int}}) return print(io, "HeisenbergGroup($(n); parameter=:field)") end -translate_diff(::HeisenbergGroup, p, q, X, ::LeftForwardAction) = X -translate_diff(::HeisenbergGroup, p, q, X, ::RightBackwardAction) = p \ X * p +# note: this implementation is not optimal +adjoint_action!(::HeisenbergGroup, Y, p, X, ::LeftAction) = copyto!(Y, p * X * inv(p)) +adjoint_action!(::HeisenbergGroup, Y, p, X, ::RightAction) = copyto!(Y, p \ X * p) -function translate_diff!(G::HeisenbergGroup, Y, p, q, X, conv::ActionDirectionAndSide) - return copyto!(Y, translate_diff(G, p, q, X, conv)) -end diff --git a/src/groups/power_group.jl b/src/groups/power_group.jl index d1bf597f81..d6a5f17853 100644 --- a/src/groups/power_group.jl +++ b/src/groups/power_group.jl @@ -259,44 +259,48 @@ function inverse_translate!( return x end -function translate_diff!(G::PowerGroup, Y, p, q, X, conv::ActionDirectionAndSide) +function _common_power_adjoint_action!(G, Y, p, X, conv) GM = G.manifold N = GM.manifold rep_size = representation_size(N) for i in get_iterator(GM) - translate_diff!( + adjoint_action!( N, _write(GM, rep_size, Y, i), _read(GM, rep_size, p, i), - _read(GM, rep_size, q, i), _read(GM, rep_size, X, i), conv, ) end return Y end -function translate_diff!( - G::PowerGroupNestedReplacing, - Y, - p, - q, - X, - conv::ActionDirectionAndSide, -) +adjoint_action!(G::PowerGroup, Y, p, X, conv::LeftAction) = _common_power_adjoint_action!(G, Y, p, X, conv) +function adjoint_action!(G::PowerGroup, Y, p, X, conv::RightAction) + return _common_power_adjoint_action!(G, Y, p, X, conv) +end + +function _common_power_replacing_adjoint_action!(G, Y, p, X, conv) GM = G.manifold N = GM.manifold rep_size = representation_size(N) for i in get_iterator(GM) - Y[i...] = translate_diff( - N, - _read(GM, rep_size, p, i), - _read(GM, rep_size, q, i), - _read(GM, rep_size, X, i), - conv, - ) + Y[i...] = + adjoint_action(N, _read(GM, rep_size, p, i), _read(GM, rep_size, X, i), conv) end return Y end +function adjoint_action!(G::PowerGroupNestedReplacing, Y, p, X, conv::LeftAction) + return _common_power_replacing_adjoint_action!(G, Y, p, X, conv) +end +function adjoint_action!( + G::PowerGroupNestedReplacing, + Y, + p, + X, + conv::RightAction, +) + return _common_power_replacing_adjoint_action!(G, Y, p, X, conv) +end function inverse_translate_diff!(G::PowerGroup, Y, p, q, X, conv::ActionDirectionAndSide) GM = G.manifold diff --git a/src/groups/product_group.jl b/src/groups/product_group.jl index 7565c5f925..fd537f5dd0 100644 --- a/src/groups/product_group.jl +++ b/src/groups/product_group.jl @@ -205,34 +205,26 @@ function inverse_translate!(G::ProductGroup, x, p, q, conv::ActionDirectionAndSi return x end -function translate_diff(G::ProductGroup, p, q, X, conv::ActionDirectionAndSide) - M = G.manifold - return ArrayPartition( - map( - translate_diff, - M.manifolds, - submanifold_components(G, p), - submanifold_components(G, q), - submanifold_components(G, X), - repeated(conv), - )..., - ) -end - -function translate_diff!(G::ProductGroup, Y, p, q, X, conv::ActionDirectionAndSide) +function _common_product_adjoint_action!(G, Y, p, X, conv) M = G.manifold map( - translate_diff!, + adjoint_action!, M.manifolds, submanifold_components(G, Y), submanifold_components(G, p), - submanifold_components(G, q), submanifold_components(G, X), repeated(conv), ) return Y end +function adjoint_action!(G::ProductGroup, Y, p, X, conv::LeftAction) + return _common_product_adjoint_action!(G, Y, p, X, conv) +end +function adjoint_action!(G::ProductGroup, Y, p, X, conv::RightAction) + return _common_product_adjoint_action!(G, Y, p, X, conv) +end + function inverse_translate_diff(G::ProductGroup, p, q, X, conv::ActionDirectionAndSide) M = G.manifold return ArrayPartition( diff --git a/src/groups/special_linear.jl b/src/groups/special_linear.jl index 7ffb5fde1e..fd57f987f3 100644 --- a/src/groups/special_linear.jl +++ b/src/groups/special_linear.jl @@ -72,8 +72,9 @@ function get_embedding(M::SpecialLinear{Tuple{Int},𝔽}) where {𝔽} return GeneralLinear(n, 𝔽; parameter=:field) end -inverse_translate_diff(::SpecialLinear, p, q, X, ::LeftForwardAction) = X -inverse_translate_diff(::SpecialLinear, p, q, X, ::RightBackwardAction) = p * X / p +# note: this implementation is not optimal +adjoint_action!(::SpecialLinear, Y, p, X, ::LeftAction) = copyto!(Y, p * X * inv(p)) +adjoint_action!(::SpecialLinear, Y, p, X, ::RightAction) = copyto!(Y, p \ X * p) function inverse_translate_diff!(G::SpecialLinear, Y, p, q, X, conv::ActionDirectionAndSide) return copyto!(Y, inverse_translate_diff(G, p, q, X, conv)) @@ -147,9 +148,5 @@ function Base.show(io::IO, M::SpecialLinear{Tuple{Int},𝔽}) where {𝔽} return print(io, "SpecialLinear($n, $𝔽; parameter=:field)") end -translate_diff(::SpecialLinear, p, q, X, ::LeftForwardAction) = X -translate_diff(::SpecialLinear, p, q, X, ::RightBackwardAction) = p \ X * p -function translate_diff!(G::SpecialLinear, Y, p, q, X, conv::ActionDirectionAndSide) - return copyto!(Y, translate_diff(G, p, q, X, conv)) -end +adjoint_action!(G::SpecialLinear, Y, p, q, X, conv::LeftAction) = p \ X * p