Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/mbaran/special-euclidean-revisit…
Browse files Browse the repository at this point in the history
…ed' into mbaran/special-euclidean-revisited
  • Loading branch information
mateuszbaran committed Oct 18, 2023
2 parents 30c1597 + 3732499 commit 63f5364
Show file tree
Hide file tree
Showing 22 changed files with 78 additions and 55 deletions.
26 changes: 18 additions & 8 deletions src/groups/GroupManifold.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,17 @@ end
function is_point(
::TraitList{<:IsGroupManifold},
G::GroupManifold,
e::Identity,
te::Bool=false;
e::Identity;
error::Symbol=:none,
kwargs...,
)
ie = is_identity(G, e; kwargs...)
(te && !ie) && throw(DomainError(e, "The provided identity is not a point on $G."))
if !ie
s = "The provided identity is not a point on $G."
(error === :error) && throw(DomainError(e, s))
(error === :info) && @info s
(error === :warn) && @warn s
end
return ie
end

Expand All @@ -83,16 +88,21 @@ function is_vector(
G::GroupManifold,
e::Identity,
X,
te::Bool=false,
cbp=true;
cbp::Bool;
error::Symbol=:none,
kwargs...,
)
if cbp
ie = is_identity(G, e; kwargs...)
(te && !ie) && throw(DomainError(e, "The provided identity is not a point on $G."))
(!te && !ie) && return false
if !ie
s = "The provided identity is not a point on $G."
(error === :error) && throw(DomainError(e, s))
(error === :info) && @info s
(error === :warn) && @warn s
return false
end
end
return is_vector(G.manifold, identity_element(G), X, te, false; kwargs...)
return is_vector(G.manifold, identity_element(G), X, false, te; kwargs...)
end

Base.show(io::IO, G::GroupManifold) = print(io, "GroupManifold($(G.manifold), $(G.op))")
31 changes: 22 additions & 9 deletions src/groups/group.jl
Original file line number Diff line number Diff line change
Expand Up @@ -384,12 +384,17 @@ end
function is_point(
::TraitList{<:IsGroupManifold},
G::AbstractDecoratorManifold,
e::Identity,
te::Bool=false;
e::Identity;
error::Symbol=:none,
kwargs...,
)
ie = is_identity(G, e; kwargs...)
(te && !ie) && throw(DomainError(e, "The provided identity is not a point on $G."))
if !ie
s = "The provided identity is not a point on $G."
(error === :error) && throw(DomainError(e, s))
(error === :info) && @info s
(error === :warn) && @warn s
end
return ie
end

Expand All @@ -398,16 +403,24 @@ function is_vector(
G::AbstractDecoratorManifold,
e::Identity,
X,
te::Bool=false,
cbp=true;
cbp::Bool=true;
error::Symbol=:none,
kwargs...,
)
if cbp
# pass te down so this throws an error if te=true
# if !te and is_point was false -> return false, otherwise continue
(!te && !is_point(G, e, te; kwargs...)) && return false
# pass te down so this throws an error if error=:error
# if error is not `:error` and is_point was false -> return false, otherwise continue
(!is_point(G, e; error=error, kwargs...)) && return false
end
return is_vector(next_trait(t), G, identity_element(G), X, te, false; kwargs...)
return is_vector(
next_trait(t),
G,
identity_element(G),
X,
false;
error=error,
kwargs...,
)
end

@doc raw"""
Expand Down
10 changes: 4 additions & 6 deletions src/manifolds/MetricManifold.jl
Original file line number Diff line number Diff line change
Expand Up @@ -426,23 +426,21 @@ is_default_metric(::AbstractManifold, ::AbstractMetric) = false
function is_point(
::TraitList{IsMetricManifold},
M::MetricManifold{𝔽,TM,G},
p,
te::Bool=false;
p;
kwargs...,
) where {𝔽,G<:AbstractMetric,TM<:AbstractManifold}
return is_point(M.manifold, p, te; kwargs...)
return is_point(M.manifold, p; kwargs...)
end

function is_vector(
::TraitList{IsMetricManifold},
M::MetricManifold{𝔽,TM,G},
p,
X,
te::Bool=false,
cbp=true;
cbp::Bool=true;
kwargs...,
) where {𝔽,G<:AbstractMetric,TM<:AbstractManifold}
return is_vector(M.manifold, p, X, te, cbp; kwargs...)
return is_vector(M.manifold, p, X, cbp; kwargs...)
end

@doc raw"""
Expand Down
2 changes: 1 addition & 1 deletion test/groups/circle_group.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ using Manifolds:
@test !is_point(G, Identity(AdditionOperation()))
ef = Identity(AdditionOperation())
@test_throws DomainError is_point(G, ef; error=:error)
@test_throws DomainError is_vector(G, ef, X, true; check_base_point=true)
@test_throws DomainError is_vector(G, ef, X, true; error=:error)
end

@testset "scalar points" begin
Expand Down
4 changes: 2 additions & 2 deletions test/groups/general_linear.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ using NLsolve
@test_throws DomainError is_point(G, Float64[0 0 0; 0 1 1; 1 1 1]; error=:error)
@test is_point(G, Float64[0 0 1; 0 1 1; 1 1 1]; error=:error)
@test is_point(G, Identity(G); error=:error)
@test_throws ManifoldDomainError is_vector(
@test_throws DomainError is_vector(
G,
Float64[0 1 1; 0 1 1; 1 0 0],
randn(3, 3);
Expand Down Expand Up @@ -151,7 +151,7 @@ using NLsolve
Float64[0 0 0; 0 1 1; 1 1 1];
error=:error,
)
@test_throws ManifoldDomainError is_vector(
@test_throws DomainError is_vector(
G,
ComplexF64[im im; im im],
randn(ComplexF64, 2, 2);
Expand Down
6 changes: 3 additions & 3 deletions test/groups/general_unitary_groups.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,14 +160,14 @@ include("group_utils.jl")
Xe = ones(2, 2)
X = project(SU2, q, Xe)
@test is_vector(SU2, q, X)
@test_throws ManifoldDomainError is_vector(SU2, p, X, true; error=:error) # base point wrong
@test_throws DomainError is_vector(SU2, p, X, true; error=:error) # base point wrong
@test_throws DomainError is_vector(SU2, q, Xe, true; error=:error) # Xe not skew hermitian
@test_throws DomainError is_vector(
SU2,
Identity(AdditionOperation()),
Xe,
true,
true,
true;
error=:error,
) # base point wrong
e = Identity(MultiplicationOperation())
@test_throws DomainError is_vector(SU2, e, Xe, true; error=:error) # Xe not skew hermitian
Expand Down
4 changes: 2 additions & 2 deletions test/groups/special_linear.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ using NLsolve
@test_throws DomainError is_point(G, Float64[1 3 3; 1 1 2; 1 2 3]; error=:error)
@test is_point(G, Float64[1 1 1; 2 2 1; 2 3 3]; error=:error)
@test is_point(G, Identity(G); error=:error)
@test_throws ManifoldDomainError is_vector(
@test_throws DomainError is_vector(
G,
Float64[2 3 2; 3 1 2; 1 1 1],
randn(3, 3);
Expand Down Expand Up @@ -136,7 +136,7 @@ using NLsolve
@test_throws DomainError is_point(G, ComplexF64[1 im; im 1]; error=:error)
@test is_point(G, ComplexF64[im 1; -2 im]; error=:error)
@test is_point(G, Identity(G); error=:error)
@test_throws ManifoldDomainError is_vector(
@test_throws DomainError is_vector(
G,
ComplexF64[-1+im -1; -im 1],
ComplexF64[1-im 1+im; 1 -1+im];
Expand Down
2 changes: 2 additions & 0 deletions test/groups/special_orthogonal.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
include("../utils.jl")
include("group_utils.jl")

using Manifolds: LeftForwardAction, RightBackwardAction

@testset "Special Orthogonal group" begin
for n in [2, 3]
G = SpecialOrthogonal(n)
Expand Down
2 changes: 1 addition & 1 deletion test/manifolds/centered_matrices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ include("../utils.jl")
@test_throws DomainError is_point(M, D; error=:error)
@test check_vector(M, A, A) === nothing
@test_throws DomainError is_vector(M, A, D; error=:error)
@test_throws ManifoldDomainError is_vector(M, D, A; error=:error)
@test_throws DomainError is_vector(M, D, A; error=:error)
@test_throws ManifoldDomainError is_vector(M, A, B; error=:error)
@test manifold_dimension(M) == 4
@test A == project!(M, A, A)
Expand Down
4 changes: 2 additions & 2 deletions test/manifolds/essential_manifold.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ include("../utils.jl")
@test_throws DomainError is_vector(
M,
p1,
[0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0],
true,
[0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0];
error=:error,
)
@test !is_vector(M, np1, [0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0])
@test !is_vector(M, p1, p2)
Expand Down
2 changes: 1 addition & 1 deletion test/manifolds/fixed_rank.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ include("../utils.jl")
UMVTVector(zeros(2, 1), zeros(1, 2), zeros(2, 2)),
)
@test !is_vector(M, SVDMPoint([1.0 0.0; 0.0 0.0], 2), X)
@test_throws ManifoldDomainError is_vector(
@test_throws DomainError is_vector(
M,
SVDMPoint([1.0 0.0; 0.0 0.0], 2),
X;
Expand Down
4 changes: 2 additions & 2 deletions test/manifolds/grassmann.jl
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,8 @@ include("../utils.jl")
@test_throws ManifoldDomainError is_vector(
M,
[1.0 0.0; 0.0 1.0; 0.0 0.0],
ones(3, 2),
true,
ones(3, 2);
error=:error,
)
@test is_vector(
M,
Expand Down
2 changes: 1 addition & 1 deletion test/manifolds/hyperbolic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ include("../utils.jl")
@test_throws DomainError is_point(M, [2.0, 0.0, 0.0]; error=:error)
@test !is_point(M, [2.0, 0.0, 0.0])
@test !is_vector(M, [1.0, 0.0, 0.0], [1.0, 0.0, 0.0])
@test_throws ManifoldDomainError is_vector(
@test_throws DomainError is_vector(
M,
[1.0, 0.0, 0.0],
[1.0, 0.0, 0.0];
Expand Down
4 changes: 2 additions & 2 deletions test/manifolds/multinomial_matrices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ include("../utils.jl")
@test_throws CompositeManifoldError{ComponentManifoldError{Int64,DomainError}} is_vector(
M,
p2,
[-1.0, 0.0, 0.0],
true,
[-1.0, 0.0, 0.0];
error=:error,
)
@test !is_vector(M, p2, [-1.0, 0.0, 0.0])
@test_throws DomainError is_vector(M, p, [-1.0, 0.0, 0.0]; error=:error)
Expand Down
4 changes: 2 additions & 2 deletions test/manifolds/oblique.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ include("../utils.jl")
@test_throws CompositeManifoldError{ComponentManifoldError{Int64,DomainError}} is_vector(
M,
p2,
[0.0, 0.0, 0.0],
true,
[0.0, 0.0, 0.0];
error=:error,
)
@test !is_vector(M, p2, [0.0, 0.0, 0.0])
@test_throws DomainError is_vector(M, p, [0.0, 0.0, 0.0]; error=:error) # p wrong
Expand Down
2 changes: 1 addition & 1 deletion test/manifolds/probability_simplex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ include("../utils.jl")
@test !is_flat(M)
@test is_vector(M, p, X)
@test is_vector(M, p, Y)
@test_throws ManifoldDomainError is_vector(M, p .+ 1, X; error=:error)
@test_throws DomainError is_vector(M, p .+ 1, X; error=:error)
@test_throws ManifoldDomainError is_vector(M, p, zeros(4); error=:error)
@test_throws DomainError is_vector(M, p, Y .+ 1; error=:error)

Expand Down
2 changes: 1 addition & 1 deletion test/manifolds/product_manifold.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ using RecursiveArrayTools: ArrayPartition
# test that arrays are not points
@test_throws DomainError is_point(Mse, [1, 2]; error=:error)
@test check_point(Mse, [1, 2]) isa DomainError
@test_throws DomainError is_vector(Mse, 1, [1, 2], true; check_base_point=false)
@test_throws DomainError is_vector(Mse, 1, [1, 2]; error=:error, check_base_point=false)
@test check_vector(Mse, 1, [1, 2]; check_base_point=false) isa DomainError
#default fallbacks for check_size, Product not working with Arrays
@test Manifolds.check_size(Mse, zeros(2)) isa DomainError
Expand Down
6 changes: 3 additions & 3 deletions test/manifolds/skewhermitian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@ end
@test_throws DomainError is_point(M, D; error=:error)
@test check_vector(M, B_skewsym, B_skewsym) === nothing
@test_throws DomainError is_vector(M, B_skewsym, A; error=:error)
@test_throws ManifoldDomainError is_vector(M, A, B_skewsym; error=:error)
@test_throws DomainError is_vector(M, A, B_skewsym; error=:error)
@test_throws DomainError is_vector(M, B_skewsym, D; error=:error)
@test_throws ManifoldDomainError is_vector(
M,
B_skewsym,
1 * im * zero_vector(M, B_skewsym),
true,
1 * im * zero_vector(M, B_skewsym);
error=:error,
)
@test manifold_dimension(M) == 3
@test manifold_dimension(M_complex) == 9
Expand Down
2 changes: 1 addition & 1 deletion test/manifolds/sphere_symmetric_matrices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ include("../utils.jl")
@test_throws ManifoldDomainError is_vector(M, A, B; error=:error)
@test_throws ManifoldDomainError is_vector(M, A, C; error=:error)
@test_throws ManifoldDomainError is_vector(M, A, D; error=:error)
@test_throws ManifoldDomainError is_vector(M, D, A; error=:error)
@test_throws DomainError is_vector(M, D, A; error=:error)
@test_throws ManifoldDomainError is_vector(M, A, E; error=:error)
@test_throws DomainError is_vector(M, J, K; error=:error)
@test manifold_dimension(M) == 5
Expand Down
4 changes: 2 additions & 2 deletions test/manifolds/stiefel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ include("../utils.jl")
@test !is_point(M, 2 * x)
@test_throws DomainError !is_point(M, 2 * x; error=:error)
@test !is_vector(M, 2 * x, v)
@test_throws ManifoldDomainError !is_vector(M, 2 * x, v; error=:error)
@test_throws DomainError !is_vector(M, 2 * x, v; error=:error)
@test !is_vector(M, x, y)
@test_throws DomainError is_vector(M, x, y; error=:error)
test_manifold(
Expand Down Expand Up @@ -227,7 +227,7 @@ include("../utils.jl")
@test !is_point(M, 2 * x)
@test_throws DomainError !is_point(M, 2 * x; error=:error)
@test !is_vector(M, 2 * x, v)
@test_throws ManifoldDomainError !is_vector(M, 2 * x, v; error=:error)
@test_throws DomainError !is_vector(M, 2 * x, v; error=:error)
@test !is_vector(M, x, y)
@test_throws DomainError is_vector(M, x, y; error=:error)
test_manifold(
Expand Down
6 changes: 3 additions & 3 deletions test/manifolds/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@ include("../utils.jl")
@test_throws ManifoldDomainError is_point(M, D; error=:error) #embedding changes type
@test check_vector(M, B_sym, B_sym) === nothing
@test_throws DomainError is_vector(M, B_sym, A; error=:error)
@test_throws ManifoldDomainError is_vector(M, A, B_sym; error=:error)
@test_throws DomainError is_vector(M, A, B_sym; error=:error)
@test_throws ManifoldDomainError is_vector(M, B_sym, D; error=:error)
@test_throws ManifoldDomainError is_vector(
M,
B_sym,
1 * im * zero_vector(M, B_sym),
true,
1 * im * zero_vector(M, B_sym);
error=:error,
)
@test manifold_dimension(M) == 6
@test manifold_dimension(M_complex) == 9
Expand Down
4 changes: 2 additions & 2 deletions test/manifolds/symplectic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,9 @@ using ManifoldDiff
@testset "Generate random points/tangent vectors" begin
M_big = Symplectic(20)
p_big = rand(M_big)
@test is_point(M_big, p_big, true; atol=1.0e-12)
@test is_point(M_big, p_big; error=:error, atol=1.0e-12)
X_big = rand(M_big; vector_at=p_big)
@test is_vector(M_big, p_big, X_big, true; atol=1.0e-12)
@test is_vector(M_big, p_big, X_big; error=:error, atol=1.0e-12)
end
@testset "test_manifold(Symplectic(6), ...)" begin
@testset "Type $(Matrix{Float64})" begin
Expand Down

0 comments on commit 63f5364

Please sign in to comment.