Skip to content

Commit

Permalink
Implement translate_diff and inv_diff for all groups (#679)
Browse files Browse the repository at this point in the history
  • Loading branch information
olivierverdier committed Nov 13, 2023
1 parent 5d267d5 commit 1301240
Show file tree
Hide file tree
Showing 9 changed files with 97 additions and 130 deletions.
19 changes: 4 additions & 15 deletions src/groups/addition_operation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,25 +160,14 @@ lie_bracket(::AdditionGroupTrait, G::AbstractDecoratorManifold, X, Y) = zero(X)

lie_bracket!(::AdditionGroupTrait, G::AbstractDecoratorManifold, Z, X, Y) = fill!(Z, 0)

function translate_diff(
::AdditionGroupTrait,
G::AbstractDecoratorManifold,
p,
q,
X,
::ActionDirectionAndSide,
)
return X
end

function translate_diff!(
function adjoint_action!(

Check warning on line 163 in src/groups/addition_operation.jl

View check run for this annotation

Codecov / codecov/patch

src/groups/addition_operation.jl#L163

Added line #L163 was not covered by tests
::AdditionGroupTrait,
G::AbstractDecoratorManifold,
Y,
p,
q,
X,
::ActionDirectionAndSide,
::ActionDirection,
)
return copyto!(G, Y, p, X)
return copyto!(Y, X)

Check warning on line 171 in src/groups/addition_operation.jl

View check run for this annotation

Codecov / codecov/patch

src/groups/addition_operation.jl#L171

Added line #L171 was not covered by tests
end

Check warning on line 173 in src/groups/addition_operation.jl

View workflow job for this annotation

GitHub Actions / Format Check

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/groups/addition_operation.jl:173:-
37 changes: 22 additions & 15 deletions src/groups/circle_group.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ end

Base.show(io::IO, ::CircleGroup) = print(io, "CircleGroup()")

adjoint_action(::CircleGroup, p, X) = X
adjoint_action(::CircleGroup, p, X, ::ActionDirection) = X

Check warning on line 27 in src/groups/circle_group.jl

View check run for this annotation

Codecov / codecov/patch

src/groups/circle_group.jl#L27

Added line #L27 was not covered by tests

adjoint_action!(::CircleGroup, Y, p, X) = copyto!(Y, X)
adjoint_action!(::CircleGroup, Y, p, X, ::ActionDirection) = copyto!(Y, X)

Check warning on line 29 in src/groups/circle_group.jl

View check run for this annotation

Codecov / codecov/patch

src/groups/circle_group.jl#L29

Added line #L29 was not covered by tests

function compose(
::MultiplicationGroupTrait,
Expand Down Expand Up @@ -77,22 +77,29 @@ lie_bracket(::CircleGroup, X, Y) = zero(X)

lie_bracket!(::CircleGroup, Z, X, Y) = fill!(Z, 0)

function translate_diff(::GT, p, q, X, ::ActionDirectionAndSide) where {GT<:CircleGroup}
return map(*, p, X)
_common_translate_diff(::Any, p, q, X) = map(*, p, X)
function translate_diff(G::GT, p, q, X, ::LeftForwardAction) where {GT<:CircleGroup}
return _common_translate_diff(G, p, q, X)

Check warning on line 82 in src/groups/circle_group.jl

View check run for this annotation

Codecov / codecov/patch

src/groups/circle_group.jl#L80-L82

Added lines #L80 - L82 were not covered by tests
end
function translate_diff(
::CircleGroup,
::Identity{MultiplicationOperation},
q,
X,
::ActionDirectionAndSide,
)
return X
function translate_diff(G::GT, p, q, X, ::RightForwardAction) where {GT<:CircleGroup}
return _common_translate_diff(G, p, q, X)

Check warning on line 85 in src/groups/circle_group.jl

View check run for this annotation

Codecov / codecov/patch

src/groups/circle_group.jl#L84-L85

Added lines #L84 - L85 were not covered by tests
end

function translate_diff!(G::CircleGroup, Y, p, q, X, conv::ActionDirectionAndSide)
return copyto!(Y, translate_diff(G, p, q, X, conv))
function translate_diff(G::GT, p, q, X, ::LeftBackwardAction) where {GT<:CircleGroup}
return _common_translate_diff(G, p, q, X)

Check warning on line 88 in src/groups/circle_group.jl

View check run for this annotation

Codecov / codecov/patch

src/groups/circle_group.jl#L87-L88

Added lines #L87 - L88 were not covered by tests
end
function translate_diff(G::GT, p, q, X, ::RightBackwardAction) where {GT<:CircleGroup}
return _common_translate_diff(G, p, q, X)

Check warning on line 91 in src/groups/circle_group.jl

View check run for this annotation

Codecov / codecov/patch

src/groups/circle_group.jl#L90-L91

Added lines #L90 - L91 were not covered by tests
end
translate_diff(::CircleGroup, ::Identity{MultiplicationOperation}, q, X, ::LeftForwardAction) = X

Check warning on line 93 in src/groups/circle_group.jl

View workflow job for this annotation

GitHub Actions / Format Check

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/groups/circle_group.jl:93:-translate_diff(::CircleGroup, ::Identity{MultiplicationOperation}, q, X, ::LeftForwardAction) = X src/groups/circle_group.jl:94:-translate_diff(::CircleGroup, ::Identity{MultiplicationOperation}, q, X, ::RightForwardAction) = X src/groups/circle_group.jl:95:-translate_diff(::CircleGroup, ::Identity{MultiplicationOperation}, q, X, ::LeftBackwardAction) = X src/groups/circle_group.jl:96:-translate_diff(::CircleGroup, ::Identity{MultiplicationOperation}, q, X, ::RightBackwardAction) = X src/groups/circle_group.jl:93:+function translate_diff( src/groups/circle_group.jl:94:+ ::CircleGroup, src/groups/circle_group.jl:95:+ ::Identity{MultiplicationOperation}, src/groups/circle_group.jl:96:+ q, src/groups/circle_group.jl:97:+ X, src/groups/circle_group.jl:98:+ ::LeftForwardAction, src/groups/circle_group.jl:99:+) src/groups/circle_group.jl:100:+ return X src/groups/circle_group.jl:101:+end src/groups/circle_group.jl:102:+function translate_diff( src/groups/circle_group.jl:103:+ ::CircleGroup, src/groups/circle_group.jl:104:+ ::Identity{MultiplicationOperation}, src/groups/circle_group.jl:105:+ q, src/groups/circle_group.jl:106:+ X, src/groups/circle_group.jl:107:+ ::RightForwardAction, src/groups/circle_group.jl:108:+) src/groups/circle_group.jl:109:+ return X src/groups/circle_group.jl:110:+end src/groups/circle_group.jl:111:+function translate_diff( src/groups/circle_group.jl:112:+ ::CircleGroup, src/groups/circle_group.jl:113:+ ::Identity{MultiplicationOperation}, src/groups/circle_group.jl:114:+ q, src/groups/circle_group.jl:115:+ X, src/groups/circle_group.jl:116:+ ::LeftBackwardAction, src/groups/circle_group.jl:117:+) src/groups/circle_group.jl:118:+ return X src/groups/circle_group.jl:119:+end src/groups/circle_group.jl:120:+function translate_diff( src/groups/circle_group.jl:121:+ ::CircleGroup, src/groups/circle_group.jl:122:+ ::Identity{MultiplicationOperation}, src/groups/circle_group.jl:123:+ q, src/groups/circle_group.jl:124:+ X, src/groups/circle_group.jl:125:+ ::RightBackwardAction, src/groups/circle_group.jl:126:+) src/groups/circle_group.jl:127:+ return X src/groups/circle_group.jl:128:+end
translate_diff(::CircleGroup, ::Identity{MultiplicationOperation}, q, X, ::RightForwardAction) = X
translate_diff(::CircleGroup, ::Identity{MultiplicationOperation}, q, X, ::LeftBackwardAction) = X
translate_diff(::CircleGroup, ::Identity{MultiplicationOperation}, q, X, ::RightBackwardAction) = X

Check warning on line 96 in src/groups/circle_group.jl

View check run for this annotation

Codecov / codecov/patch

src/groups/circle_group.jl#L93-L96

Added lines #L93 - L96 were not covered by tests

_common_translate_diff!(G, Y, p, q, X, conv) = copyto!(Y, translate_diff(G, p, q, X, conv))
translate_diff!(G::CircleGroup, Y, p, q, X, conv::LeftForwardAction) = _common_translate_diff!(G, Y, p, q, X, conv)

Check warning on line 99 in src/groups/circle_group.jl

View workflow job for this annotation

GitHub Actions / Format Check

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/groups/circle_group.jl:99:-translate_diff!(G::CircleGroup, Y, p, q, X, conv::LeftForwardAction) = _common_translate_diff!(G, Y, p, q, X, conv) src/groups/circle_group.jl:100:-translate_diff!(G::CircleGroup, Y, p, q, X, conv::RightForwardAction) = _common_translate_diff!(G, Y, p, q, X, conv) src/groups/circle_group.jl:101:-translate_diff!(G::CircleGroup, Y, p, q, X, conv::LeftBackwardAction) = _common_translate_diff!(G, Y, p, q, X, conv) src/groups/circle_group.jl:102:-translate_diff!(G::CircleGroup, Y, p, q, X, conv::RightBackwardAction) = _common_translate_diff!(G, Y, p, q, X, conv) src/groups/circle_group.jl:131:+function translate_diff!(G::CircleGroup, Y, p, q, X, conv::LeftForwardAction) src/groups/circle_group.jl:132:+ return _common_translate_diff!(G, Y, p, q, X, conv) src/groups/circle_group.jl:133:+end src/groups/circle_group.jl:134:+function translate_diff!(G::CircleGroup, Y, p, q, X, conv::RightForwardAction) src/groups/circle_group.jl:135:+ return _common_translate_diff!(G, Y, p, q, X, conv) src/groups/circle_group.jl:136:+end src/groups/circle_group.jl:137:+function translate_diff!(G::CircleGroup, Y, p, q, X, conv::LeftBackwardAction) src/groups/circle_group.jl:138:+ return _common_translate_diff!(G, Y, p, q, X, conv) src/groups/circle_group.jl:139:+end src/groups/circle_group.jl:140:+function translate_diff!(G::CircleGroup, Y, p, q, X, conv::RightBackwardAction) src/groups/circle_group.jl:141:+ return _common_translate_diff!(G, Y, p, q, X, conv) src/groups/circle_group.jl:142:+end
translate_diff!(G::CircleGroup, Y, p, q, X, conv::RightForwardAction) = _common_translate_diff!(G, Y, p, q, X, conv)
translate_diff!(G::CircleGroup, Y, p, q, X, conv::LeftBackwardAction) = _common_translate_diff!(G, Y, p, q, X, conv)
translate_diff!(G::CircleGroup, Y, p, q, X, conv::RightBackwardAction) = _common_translate_diff!(G, Y, p, q, X, conv)

Check warning on line 102 in src/groups/circle_group.jl

View check run for this annotation

Codecov / codecov/patch

src/groups/circle_group.jl#L98-L102

Added lines #L98 - L102 were not covered by tests

function exp_lie(::CircleGroup, X)
return map(X) do imθ
Expand Down
9 changes: 4 additions & 5 deletions src/groups/general_linear.jl
Original file line number Diff line number Diff line change
Expand Up @@ -275,9 +275,8 @@ function Base.show(io::IO, M::GeneralLinear{Tuple{Int},𝔽}) where {𝔽}
return print(io, "GeneralLinear($n, $𝔽; parameter=:field)")
end

translate_diff(::GeneralLinear, p, q, X, ::LeftForwardAction) = X
translate_diff(::GeneralLinear, p, q, X, ::RightBackwardAction) = p \ X * p
# note: this implementation is not optimal
adjoint_action!(::GeneralLinear, Y, p, X, ::LeftAction) = copyto!(Y, p * X * inv(p))

Check warning on line 279 in src/groups/general_linear.jl

View check run for this annotation

Codecov / codecov/patch

src/groups/general_linear.jl#L279

Added line #L279 was not covered by tests
adjoint_action!(::GeneralLinear, Y, p, X, ::RightAction) = copyto!(Y, p \ X * p)

Check warning on line 281 in src/groups/general_linear.jl

View workflow job for this annotation

GitHub Actions / Format Check

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/groups/general_linear.jl:281:- src/groups/general_linear.jl:282:-

function translate_diff!(G::GeneralLinear, Y, p, q, X, conv::ActionDirectionAndSide)
return copyto!(Y, translate_diff(G, p, q, X, conv))
end
31 changes: 4 additions & 27 deletions src/groups/general_unitary_groups.jl
Original file line number Diff line number Diff line change
Expand Up @@ -296,45 +296,22 @@ function Random.rand!(rng::AbstractRNG, G::GeneralUnitaryMultiplicationGroup, pX
return pX
end

function translate_diff!(
function adjoint_action!(

Check warning on line 299 in src/groups/general_unitary_groups.jl

View workflow job for this annotation

GitHub Actions / Format Check

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/groups/general_unitary_groups.jl:299:-function adjoint_action!( src/groups/general_unitary_groups.jl:300:- G::GeneralUnitaryMultiplicationGroup, src/groups/general_unitary_groups.jl:301:- Y, src/groups/general_unitary_groups.jl:302:- p, src/groups/general_unitary_groups.jl:303:- X, src/groups/general_unitary_groups.jl:304:- ::LeftAction src/groups/general_unitary_groups.jl:305:-) src/groups/general_unitary_groups.jl:299:+function adjoint_action!(G::GeneralUnitaryMultiplicationGroup, Y, p, X, ::LeftAction)

Check warning on line 299 in src/groups/general_unitary_groups.jl

View check run for this annotation

Codecov / codecov/patch

src/groups/general_unitary_groups.jl#L299

Added line #L299 was not covered by tests
G::GeneralUnitaryMultiplicationGroup,
Y,
p,
q,
X,
::LeftForwardAction,
)
return copyto!(G, Y, X)
end
function translate_diff!(
G::GeneralUnitaryMultiplicationGroup,
Y,
p,
q,
X,
::RightForwardAction,
)
copyto!(G, Y, X)
return Y
end
function translate_diff!(
G::GeneralUnitaryMultiplicationGroup,
Y,
p,
q,
X,
::LeftBackwardAction,
::LeftAction
)
copyto!(G, Y, p * X * inv(G, p))
return Y
end
function translate_diff!(
function adjoint_action!(

Check warning on line 309 in src/groups/general_unitary_groups.jl

View workflow job for this annotation

GitHub Actions / Format Check

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/groups/general_unitary_groups.jl:309:-function adjoint_action!( src/groups/general_unitary_groups.jl:310:- G::GeneralUnitaryMultiplicationGroup, src/groups/general_unitary_groups.jl:311:- Y, src/groups/general_unitary_groups.jl:312:- p, src/groups/general_unitary_groups.jl:313:- X, src/groups/general_unitary_groups.jl:314:- ::RightAction, src/groups/general_unitary_groups.jl:315:-) src/groups/general_unitary_groups.jl:303:+function adjoint_action!(G::GeneralUnitaryMultiplicationGroup, Y, p, X, ::RightAction)

Check warning on line 309 in src/groups/general_unitary_groups.jl

View check run for this annotation

Codecov / codecov/patch

src/groups/general_unitary_groups.jl#L309

Added line #L309 was not covered by tests
G::GeneralUnitaryMultiplicationGroup,
Y,
p,
q,
X,
::RightBackwardAction,
::RightAction,
)
return copyto!(G, Y, inv(G, p) * X * p)
end
Expand Down
46 changes: 25 additions & 21 deletions src/groups/group.jl
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ function is_vector(
end

@doc raw"""
adjoint_action(G::AbstractDecoratorManifold, p, X)
adjoint_action(G::AbstractDecoratorManifold, p, X, dir)
Adjoint action of the element `p` of the Lie group `G` on the element `X`
of the corresponding Lie algebra.
Expand All @@ -443,26 +443,18 @@ where ``e`` is the identity element of `G`.
Note that the adjoint representation of a Lie group isn't generally faithful.
Notably the adjoint representation of SO(2) is trivial.
"""
adjoint_action(G::AbstractDecoratorManifold, p, X)
@trait_function adjoint_action(G::AbstractDecoratorManifold, p, Xₑ)
function adjoint_action(::TraitList{<:IsGroupManifold}, G::AbstractDecoratorManifold, p, Xₑ)
Xₚ = translate_diff(G, p, Identity(G), Xₑ, LeftForwardAction())
Y = inverse_translate_diff(G, p, p, Xₚ, RightBackwardAction())
return Y
end

@trait_function adjoint_action!(G::AbstractDecoratorManifold, Y, p, Xₑ)
function adjoint_action!(
::TraitList{<:IsGroupManifold},
G::AbstractDecoratorManifold,
Y,
p,
Xₑ,
)
Xₚ = translate_diff(G, p, Identity(G), Xₑ, LeftForwardAction())
inverse_translate_diff!(G, Y, p, p, Xₚ, RightBackwardAction())
return Y
end
adjoint_action(G::AbstractDecoratorManifold, p, X, dir)
@trait_function adjoint_action(G::AbstractDecoratorManifold, p, Xₑ, dir)
@trait_function adjoint_action!(G::AbstractDecoratorManifold, Y, p, Xₑ, dir)
function adjoint_action(::TraitList{<:IsGroupManifold}, G::AbstractDecoratorManifold, p, Xₑ, dir)

Check warning on line 449 in src/groups/group.jl

View workflow job for this annotation

GitHub Actions / Format Check

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/groups/group.jl:449:-function adjoint_action(::TraitList{<:IsGroupManifold}, G::AbstractDecoratorManifold, p, Xₑ, dir) src/groups/group.jl:449:+function adjoint_action( src/groups/group.jl:450:+ ::TraitList{<:IsGroupManifold}, src/groups/group.jl:451:+ G::AbstractDecoratorManifold, src/groups/group.jl:452:+ p, src/groups/group.jl:453:+ Xₑ, src/groups/group.jl:454:+ dir, src/groups/group.jl:455:+)
Y = allocate_result(G, adjoint_action, Xₑ, p)
return adjoint_action!(G, Y, p, Xₑ, dir)

Check warning on line 451 in src/groups/group.jl

View check run for this annotation

Codecov / codecov/patch

src/groups/group.jl#L449-L451

Added lines #L449 - L451 were not covered by tests
end
# backward compatibility
adjoint_action(G::AbstractDecoratorManifold, p, X) = adjoint_action(G, p, X, LeftAction())
adjoint_action!(G::AbstractDecoratorManifold, Y, p, X) = adjoint_action!(G, Y, p, X, LeftAction())

Check warning on line 455 in src/groups/group.jl

View workflow job for this annotation

GitHub Actions / Format Check

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/groups/group.jl:455:-adjoint_action!(G::AbstractDecoratorManifold, Y, p, X) = adjoint_action!(G, Y, p, X, LeftAction()) src/groups/group.jl:461:+function adjoint_action!(G::AbstractDecoratorManifold, Y, p, X) src/groups/group.jl:462:+ return adjoint_action!(G, Y, p, X, LeftAction()) src/groups/group.jl:463:+end

Check warning on line 455 in src/groups/group.jl

View check run for this annotation

Codecov / codecov/patch

src/groups/group.jl#L454-L455

Added lines #L454 - L455 were not covered by tests
# fall back method: the right action is defined from the left action
adjoint_action!(G::AbstractDecoratorManifold, Y, p, X, ::RightAction) = adjoint_action!(G, Y, inv(G, p), X, LeftAction())

Check warning on line 457 in src/groups/group.jl

View workflow job for this annotation

GitHub Actions / Format Check

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/groups/group.jl:457:-adjoint_action!(G::AbstractDecoratorManifold, Y, p, X, ::RightAction) = adjoint_action!(G, Y, inv(G, p), X, LeftAction()) src/groups/group.jl:465:+function adjoint_action!(G::AbstractDecoratorManifold, Y, p, X, ::RightAction) src/groups/group.jl:466:+ return adjoint_action!(G, Y, inv(G, p), X, LeftAction()) src/groups/group.jl:467:+end

Check warning on line 457 in src/groups/group.jl

View check run for this annotation

Codecov / codecov/patch

src/groups/group.jl#L457

Added line #L457 was not covered by tests

@doc raw"""
adjoint_inv_diff(G::AbstractDecoratorManifold, p, X)
Expand Down Expand Up @@ -913,6 +905,18 @@ end
X,
conv::ActionDirectionAndSide=LeftForwardAction(),
)
translate_diff(::AbstractDecoratorManifold, ::Any, ::Any, X, ::LeftForwardAction) = X
translate_diff(::AbstractDecoratorManifold, ::Any, ::Any, X, ::RightForwardAction) = X
translate_diff!(G::AbstractDecoratorManifold, Y, ::Any, ::Any, X, ::LeftForwardAction) = copyto!(G, Y, X)

Check warning on line 910 in src/groups/group.jl

View workflow job for this annotation

GitHub Actions / Format Check

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/groups/group.jl:910:-translate_diff!(G::AbstractDecoratorManifold, Y, ::Any, ::Any, X, ::LeftForwardAction) = copyto!(G, Y, X) src/groups/group.jl:911:-translate_diff!(G::AbstractDecoratorManifold, Y, ::Any, ::Any, X, ::RightForwardAction) = copyto!(G, Y, X) src/groups/group.jl:912:-translate_diff!(G::AbstractDecoratorManifold, Y, p, ::Any, X, ::LeftBackwardAction) = adjoint_action!(G, Y, p, X, LeftAction()) src/groups/group.jl:913:-translate_diff!(G::AbstractDecoratorManifold, Y, p, ::Any, X, ::RightBackwardAction) = adjoint_action!(G, Y, p, X, RightAction()) src/groups/group.jl:920:+function translate_diff!( src/groups/group.jl:921:+ G::AbstractDecoratorManifold, src/groups/group.jl:922:+ Y, src/groups/group.jl:923:+ ::Any, src/groups/group.jl:924:+ ::Any, src/groups/group.jl:925:+ X, src/groups/group.jl:926:+ ::LeftForwardAction, src/groups/group.jl:927:+) src/groups/group.jl:928:+ return copyto!(G, Y, X) src/groups/group.jl:929:+end src/groups/group.jl:930:+function translate_diff!( src/groups/group.jl:931:+ G::AbstractDecoratorManifold, src/groups/group.jl:932:+ Y, src/groups/group.jl:933:+ ::Any, src/groups/group.jl:934:+ ::Any, src/groups/group.jl:935:+ X, src/groups/group.jl:936:+ ::RightForwardAction, src/groups/group.jl:937:+) src/groups/group.jl:938:+ return copyto!(G, Y, X) src/groups/group.jl:939:+end src/groups/group.jl:940:+function translate_diff!(G::AbstractDecoratorManifold, Y, p, ::Any, X, ::LeftBackwardAction) src/groups/group.jl:941:+ return adjoint_action!(G, Y, p, X, LeftAction()) src/groups/group.jl:942:+end src/groups/group.jl:943:+function translate_diff!( src/groups/group.jl:944:+ G::AbstractDecoratorManifold, src/groups/group.jl:945:+ Y, src/groups/group.jl:946:+ p, src/groups/group.jl:947:+ ::Any, src/groups/group.jl:948:+ X, src/groups/group.jl:949:+ ::RightBackwardAction, src/groups/group.jl:950:+) src/groups/group.jl:951:+ return adjoint_action!(G, Y, p, X, RightAction()) src/groups/group.jl:952:+end
translate_diff!(G::AbstractDecoratorManifold, Y, ::Any, ::Any, X, ::RightForwardAction) = copyto!(G, Y, X)
translate_diff!(G::AbstractDecoratorManifold, Y, p, ::Any, X, ::LeftBackwardAction) = adjoint_action!(G, Y, p, X, LeftAction())
translate_diff!(G::AbstractDecoratorManifold, Y, p, ::Any, X, ::RightBackwardAction) = adjoint_action!(G, Y, p, X, RightAction())

Check warning on line 913 in src/groups/group.jl

View check run for this annotation

Codecov / codecov/patch

src/groups/group.jl#L908-L913

Added lines #L908 - L913 were not covered by tests

translate_diff(::AbstractDecoratorManifold, ::Identity, q, X, ::LeftForwardAction) = X
translate_diff(::AbstractDecoratorManifold, ::Identity, q, X, ::RightForwardAction) = X
translate_diff(::AbstractDecoratorManifold, ::Identity, q, X, ::LeftBackwardAction) = X
translate_diff(::AbstractDecoratorManifold, ::Identity, q, X, ::RightBackwardAction) = X

Check warning on line 918 in src/groups/group.jl

View check run for this annotation

Codecov / codecov/patch

src/groups/group.jl#L915-L918

Added lines #L915 - L918 were not covered by tests


@doc raw"""
inverse_translate_diff(G::AbstractDecoratorManifold, p, q, X, conv::ActionDirectionAndSide=LeftForwardAction())
Expand Down
8 changes: 3 additions & 5 deletions src/groups/heisenberg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -413,9 +413,7 @@ function Base.show(io::IO, M::HeisenbergGroup{Tuple{Int}})
return print(io, "HeisenbergGroup($(n); parameter=:field)")
end

translate_diff(::HeisenbergGroup, p, q, X, ::LeftForwardAction) = X
translate_diff(::HeisenbergGroup, p, q, X, ::RightBackwardAction) = p \ X * p
# note: this implementation is not optimal
adjoint_action!(::HeisenbergGroup, Y, p, X, ::LeftAction) = copyto!(Y, p * X * inv(p))

Check warning on line 417 in src/groups/heisenberg.jl

View check run for this annotation

Codecov / codecov/patch

src/groups/heisenberg.jl#L417

Added line #L417 was not covered by tests
adjoint_action!(::HeisenbergGroup, Y, p, X, ::RightAction) = copyto!(Y, p \ X * p)

function translate_diff!(G::HeisenbergGroup, Y, p, q, X, conv::ActionDirectionAndSide)
return copyto!(Y, translate_diff(G, p, q, X, conv))
end
40 changes: 22 additions & 18 deletions src/groups/power_group.jl
Original file line number Diff line number Diff line change
Expand Up @@ -259,44 +259,48 @@ function inverse_translate!(
return x
end

function translate_diff!(G::PowerGroup, Y, p, q, X, conv::ActionDirectionAndSide)
function _common_power_adjoint_action!(G, Y, p, X, conv)

Check warning on line 262 in src/groups/power_group.jl

View check run for this annotation

Codecov / codecov/patch

src/groups/power_group.jl#L262

Added line #L262 was not covered by tests
GM = G.manifold
N = GM.manifold
rep_size = representation_size(N)
for i in get_iterator(GM)
translate_diff!(
adjoint_action!(

Check warning on line 267 in src/groups/power_group.jl

View check run for this annotation

Codecov / codecov/patch

src/groups/power_group.jl#L267

Added line #L267 was not covered by tests
N,
_write(GM, rep_size, Y, i),
_read(GM, rep_size, p, i),
_read(GM, rep_size, q, i),
_read(GM, rep_size, X, i),
conv,
)
end
return Y
end
function translate_diff!(
G::PowerGroupNestedReplacing,
Y,
p,
q,
X,
conv::ActionDirectionAndSide,
)
adjoint_action!(G::PowerGroup, Y, p, X, conv::LeftAction) = _common_power_adjoint_action!(G, Y, p, X, conv)
function adjoint_action!(G::PowerGroup, Y, p, X, conv::RightAction)
return _common_power_adjoint_action!(G, Y, p, X, conv)

Check warning on line 279 in src/groups/power_group.jl

View check run for this annotation

Codecov / codecov/patch

src/groups/power_group.jl#L277-L279

Added lines #L277 - L279 were not covered by tests
end

function _common_power_replacing_adjoint_action!(G, Y, p, X, conv)

Check warning on line 282 in src/groups/power_group.jl

View check run for this annotation

Codecov / codecov/patch

src/groups/power_group.jl#L282

Added line #L282 was not covered by tests
GM = G.manifold
N = GM.manifold
rep_size = representation_size(N)
for i in get_iterator(GM)
Y[i...] = translate_diff(
N,
_read(GM, rep_size, p, i),
_read(GM, rep_size, q, i),
_read(GM, rep_size, X, i),
conv,
)
Y[i...] =

Check warning on line 287 in src/groups/power_group.jl

View check run for this annotation

Codecov / codecov/patch

src/groups/power_group.jl#L287

Added line #L287 was not covered by tests
adjoint_action(N, _read(GM, rep_size, p, i), _read(GM, rep_size, X, i), conv)
end
return Y
end
function adjoint_action!(G::PowerGroupNestedReplacing, Y, p, X, conv::LeftAction)
return _common_power_replacing_adjoint_action!(G, Y, p, X, conv)

Check warning on line 293 in src/groups/power_group.jl

View check run for this annotation

Codecov / codecov/patch

src/groups/power_group.jl#L292-L293

Added lines #L292 - L293 were not covered by tests
end
function adjoint_action!(

Check warning on line 295 in src/groups/power_group.jl

View check run for this annotation

Codecov / codecov/patch

src/groups/power_group.jl#L295

Added line #L295 was not covered by tests
G::PowerGroupNestedReplacing,
Y,
p,
X,
conv::RightAction,
)
return _common_power_replacing_adjoint_action!(G, Y, p, X, conv)

Check warning on line 302 in src/groups/power_group.jl

View check run for this annotation

Codecov / codecov/patch

src/groups/power_group.jl#L302

Added line #L302 was not covered by tests
end

function inverse_translate_diff!(G::PowerGroup, Y, p, q, X, conv::ActionDirectionAndSide)
GM = G.manifold
Expand Down
26 changes: 9 additions & 17 deletions src/groups/product_group.jl
Original file line number Diff line number Diff line change
Expand Up @@ -205,34 +205,26 @@ function inverse_translate!(G::ProductGroup, x, p, q, conv::ActionDirectionAndSi
return x
end

function translate_diff(G::ProductGroup, p, q, X, conv::ActionDirectionAndSide)
M = G.manifold
return ArrayPartition(
map(
translate_diff,
M.manifolds,
submanifold_components(G, p),
submanifold_components(G, q),
submanifold_components(G, X),
repeated(conv),
)...,
)
end

function translate_diff!(G::ProductGroup, Y, p, q, X, conv::ActionDirectionAndSide)
function _common_product_adjoint_action!(G, Y, p, X, conv)

Check warning on line 208 in src/groups/product_group.jl

View check run for this annotation

Codecov / codecov/patch

src/groups/product_group.jl#L208

Added line #L208 was not covered by tests
M = G.manifold
map(
translate_diff!,
adjoint_action!,
M.manifolds,
submanifold_components(G, Y),
submanifold_components(G, p),
submanifold_components(G, q),
submanifold_components(G, X),
repeated(conv),
)
return Y
end

function adjoint_action!(G::ProductGroup, Y, p, X, conv::LeftAction)
return _common_product_adjoint_action!(G, Y, p, X, conv)

Check warning on line 222 in src/groups/product_group.jl

View check run for this annotation

Codecov / codecov/patch

src/groups/product_group.jl#L221-L222

Added lines #L221 - L222 were not covered by tests
end
function adjoint_action!(G::ProductGroup, Y, p, X, conv::RightAction)
return _common_product_adjoint_action!(G, Y, p, X, conv)

Check warning on line 225 in src/groups/product_group.jl

View check run for this annotation

Codecov / codecov/patch

src/groups/product_group.jl#L224-L225

Added lines #L224 - L225 were not covered by tests
end

function inverse_translate_diff(G::ProductGroup, p, q, X, conv::ActionDirectionAndSide)
M = G.manifold
return ArrayPartition(
Expand Down
11 changes: 4 additions & 7 deletions src/groups/special_linear.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,9 @@ function get_embedding(M::SpecialLinear{Tuple{Int},𝔽}) where {𝔽}
return GeneralLinear(n, 𝔽; parameter=:field)
end

inverse_translate_diff(::SpecialLinear, p, q, X, ::LeftForwardAction) = X
inverse_translate_diff(::SpecialLinear, p, q, X, ::RightBackwardAction) = p * X / p
# note: this implementation is not optimal
adjoint_action!(::SpecialLinear, Y, p, X, ::LeftAction) = copyto!(Y, p * X * inv(p))
adjoint_action!(::SpecialLinear, Y, p, X, ::RightAction) = copyto!(Y, p \ X * p)

Check warning on line 77 in src/groups/special_linear.jl

View check run for this annotation

Codecov / codecov/patch

src/groups/special_linear.jl#L76-L77

Added lines #L76 - L77 were not covered by tests

function inverse_translate_diff!(G::SpecialLinear, Y, p, q, X, conv::ActionDirectionAndSide)
return copyto!(Y, inverse_translate_diff(G, p, q, X, conv))
Expand Down Expand Up @@ -147,9 +148,5 @@ function Base.show(io::IO, M::SpecialLinear{Tuple{Int},𝔽}) where {𝔽}
return print(io, "SpecialLinear($n, $𝔽; parameter=:field)")
end

translate_diff(::SpecialLinear, p, q, X, ::LeftForwardAction) = X
translate_diff(::SpecialLinear, p, q, X, ::RightBackwardAction) = p \ X * p

function translate_diff!(G::SpecialLinear, Y, p, q, X, conv::ActionDirectionAndSide)
return copyto!(Y, translate_diff(G, p, q, X, conv))
end
adjoint_action!(G::SpecialLinear, Y, p, q, X, conv::LeftAction) = p \ X * p

0 comments on commit 1301240

Please sign in to comment.