From 887164c626a6dfb85af2ab48dda66e0373ee934d Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Mon, 3 Jul 2023 21:06:15 +0200 Subject: [PATCH 1/7] Add interface for optionally static size --- src/DefaultManifold.jl | 30 +++++++++++++++++++++++------- src/maintypes.jl | 28 ++++++++++++++++++++++++++++ test/default_manifold.jl | 5 +++++ 3 files changed, 56 insertions(+), 7 deletions(-) diff --git a/src/DefaultManifold.jl b/src/DefaultManifold.jl index 0aa37b75..3440196f 100644 --- a/src/DefaultManifold.jl +++ b/src/DefaultManifold.jl @@ -9,11 +9,20 @@ This manifold further illustrates how to type your manifold points and tangent v that the interface does not require this, but it might be handy in debugging and educative situations to verify correctness of involved variabes. """ -struct DefaultManifold{𝔽,T<:NTuple{N,Int} where {N}} <: AbstractManifold{𝔽} +struct DefaultManifold{𝔽,T<:AbstractManifoldSize} <: AbstractManifold{𝔽} size::T end -function DefaultManifold(n::Vararg{Int}; field = ℝ) - return DefaultManifold{field,typeof(n)}(n) +function ManifoldsBase.DefaultManifold( + n::Vararg{Int}; + field = ManifoldsBase.ℝ, + static = false, +) + if static + size = ManifoldsBase.StaticSize(n) + else + size = ManifoldsBase.RTSize(n) + end + return ManifoldsBase.DefaultManifold{field,typeof(size)}(size) end change_representer!(M::DefaultManifold, Y, ::EuclideanMetric, p, X) = copyto!(M, Y, p, X) @@ -112,7 +121,8 @@ is_flat(::DefaultManifold) = true log!(::DefaultManifold, Y, p, q) = (Y .= q .- p) function manifold_dimension(M::DefaultManifold{𝔽}) where {𝔽} - return length(M.size) == 0 ? 1 : *(M.size...) * real_dimension(𝔽) + size = getsize(M.size) + return length(size) == 0 ? 1 : *(size...) * real_dimension(𝔽) end number_system(::DefaultManifold{𝔽}) where {𝔽} = 𝔽 @@ -122,10 +132,16 @@ norm(::DefaultManifold, p, X) = norm(X) project!(::DefaultManifold, q, p) = copyto!(q, p) project!(::DefaultManifold, Y, p, X) = copyto!(Y, X) -representation_size(M::DefaultManifold) = M.size +representation_size(M::DefaultManifold) = getsize(M.size) -function Base.show(io::IO, M::DefaultManifold{𝔽}) where {𝔽} - return print(io, "DefaultManifold($(join(M.size, ", ")); field = $(𝔽))") +function Base.show(io::IO, M::DefaultManifold{𝔽,<:StaticSize}) where {𝔽} + return print( + io, + "DefaultManifold($(join(getsize(M.size), ", ")); field = $(𝔽), static = true)", + ) +end +function Base.show(io::IO, M::DefaultManifold{𝔽,<:RTSize}) where {𝔽} + return print(io, "DefaultManifold($(join(getsize(M.size), ", ")); field = $(𝔽))") end function parallel_transport_to!(::DefaultManifold, Y, p, X, q) diff --git a/src/maintypes.jl b/src/maintypes.jl index d933a6ce..98876f74 100644 --- a/src/maintypes.jl +++ b/src/maintypes.jl @@ -36,3 +36,31 @@ matrix internally, it is possible to use [`@manifold_element_forwards`](@ref) an [`@default_manifold_fallbacks`](@ref) to reduce implementation overhead. """ abstract type AbstractManifoldPoint end + +""" + abstract type AbstractManifoldSize end + +Abstract representation of manifold size. Can be either [`StaticSize`](@ref) or +[`RTSize`](@ref). +""" +abstract type AbstractManifoldSize end + +""" + StaticSize{T} + +Static size of a manifold. +""" +struct StaticSize{T} <: AbstractManifoldSize end +StaticSize(t::NTuple) = StaticSize{t}() + +""" + RTSize{TS<:NTuple{N,Int} where N} + +Runtime size of a manifold. +""" +struct RTSize{TS<:NTuple{N,Int} where {N}} <: AbstractManifoldSize + size::TS +end + +getsize(::StaticSize{T}) where {T} = T +getsize(S::RTSize) = S.size diff --git a/test/default_manifold.jl b/test/default_manifold.jl index 1d2bef1f..12a846f5 100644 --- a/test/default_manifold.jl +++ b/test/default_manifold.jl @@ -870,4 +870,9 @@ Base.size(x::MatrixVectorTransport) = (size(x.m, 2),) @test copy(M, p) === p @test copy(M, p, X) === X end + + @testset "static" begin + MS = ManifoldsBase.DefaultManifold(3; static = true) + @test (@inferred representation_size(MS)) == (3,) + end end From db4106642eb256d5dc893cf669011572eaf69095 Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Mon, 3 Jul 2023 21:35:30 +0200 Subject: [PATCH 2/7] increase coverage --- test/default_manifold.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/default_manifold.jl b/test/default_manifold.jl index 12a846f5..dcc1fabc 100644 --- a/test/default_manifold.jl +++ b/test/default_manifold.jl @@ -874,5 +874,6 @@ Base.size(x::MatrixVectorTransport) = (size(x.m, 2),) @testset "static" begin MS = ManifoldsBase.DefaultManifold(3; static = true) @test (@inferred representation_size(MS)) == (3,) + @test repr(MS) == "DefaultManifold(3; field = ℝ, static = true)" end end From 5026b91a38916fb31ef45d5ede28b982084a2a0c Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Tue, 4 Jul 2023 11:14:24 +0200 Subject: [PATCH 3/7] address review comments --- Project.toml | 2 +- src/DefaultManifold.jl | 42 +++++++++++++++++++++++++--------------- src/maintypes.jl | 30 ++++++++++++++-------------- test/default_manifold.jl | 7 ++++--- 4 files changed, 47 insertions(+), 34 deletions(-) diff --git a/Project.toml b/Project.toml index f37eb3c4..5055318a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ManifoldsBase" uuid = "3362f125-f0bb-47a3-aa74-596ffd7ef2fb" authors = ["Seth Axen ", "Mateusz Baran ", "Ronny Bergmann ", "Antoine Levitt "] -version = "0.14.7" +version = "0.14.8" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/DefaultManifold.jl b/src/DefaultManifold.jl index 3440196f..9a0e09af 100644 --- a/src/DefaultManifold.jl +++ b/src/DefaultManifold.jl @@ -8,21 +8,31 @@ to build one's own manifold. It is a simplified/shortened variant of `Euclidean` This manifold further illustrates how to type your manifold points and tangent vectors. Note that the interface does not require this, but it might be handy in debugging and educative situations to verify correctness of involved variabes. + +# Constructor + + DefaultManifold(n::Int...; field = ℝ, parameter::Symbol = :field) + + +Arguments: + +- `n`: shape of array representing points on the manifold. +- `field`: field over which the manifold is defined. Either `ℝ`, `ℂ` or `ℍ`. +- `parameter`: whether a type parameter should be used to store `n`. By default size + is stored in a field. Value can either be `:field` or `:type`. """ -struct DefaultManifold{𝔽,T<:AbstractManifoldSize} <: AbstractManifold{𝔽} +struct DefaultManifold{𝔽,T<:AbstractManifoldParameter} <: AbstractManifold{𝔽} size::T end -function ManifoldsBase.DefaultManifold( - n::Vararg{Int}; - field = ManifoldsBase.ℝ, - static = false, -) - if static - size = ManifoldsBase.StaticSize(n) +function DefaultManifold(n::Vararg{Int}; field = ℝ, parameter::Symbol = :field) + if parameter === :field + size = FieldParameter(n) + elseif parameter === :type + size = TypeParameter(n) else - size = ManifoldsBase.RTSize(n) + throw(ArgumentError("Parameter can be either :field or :type. Given: $parameter")) end - return ManifoldsBase.DefaultManifold{field,typeof(size)}(size) + return DefaultManifold{field,typeof(size)}(size) end change_representer!(M::DefaultManifold, Y, ::EuclideanMetric, p, X) = copyto!(M, Y, p, X) @@ -121,7 +131,7 @@ is_flat(::DefaultManifold) = true log!(::DefaultManifold, Y, p, q) = (Y .= q .- p) function manifold_dimension(M::DefaultManifold{𝔽}) where {𝔽} - size = getsize(M.size) + size = get_parameter(M.size) return length(size) == 0 ? 1 : *(size...) * real_dimension(𝔽) end @@ -132,16 +142,16 @@ norm(::DefaultManifold, p, X) = norm(X) project!(::DefaultManifold, q, p) = copyto!(q, p) project!(::DefaultManifold, Y, p, X) = copyto!(Y, X) -representation_size(M::DefaultManifold) = getsize(M.size) +representation_size(M::DefaultManifold) = get_parameter(M.size) -function Base.show(io::IO, M::DefaultManifold{𝔽,<:StaticSize}) where {𝔽} +function Base.show(io::IO, M::DefaultManifold{𝔽,<:TypeParameter}) where {𝔽} return print( io, - "DefaultManifold($(join(getsize(M.size), ", ")); field = $(𝔽), static = true)", + "DefaultManifold($(join(get_parameter(M.size), ", ")); field = $(𝔽), parameter = :type)", ) end -function Base.show(io::IO, M::DefaultManifold{𝔽,<:RTSize}) where {𝔽} - return print(io, "DefaultManifold($(join(getsize(M.size), ", ")); field = $(𝔽))") +function Base.show(io::IO, M::DefaultManifold{𝔽,<:FieldParameter}) where {𝔽} + return print(io, "DefaultManifold($(join(get_parameter(M.size), ", ")); field = $(𝔽))") end function parallel_transport_to!(::DefaultManifold, Y, p, X, q) diff --git a/src/maintypes.jl b/src/maintypes.jl index 98876f74..f453843f 100644 --- a/src/maintypes.jl +++ b/src/maintypes.jl @@ -38,29 +38,31 @@ matrix internally, it is possible to use [`@manifold_element_forwards`](@ref) an abstract type AbstractManifoldPoint end """ - abstract type AbstractManifoldSize end + abstract type AbstractManifoldParameter end -Abstract representation of manifold size. Can be either [`StaticSize`](@ref) or -[`RTSize`](@ref). +Abstract representation of numeric parameters for a manifold type. Can be either +[`TypeParameter`](@ref) or [`FieldParameter`](@ref). """ -abstract type AbstractManifoldSize end +abstract type AbstractManifoldParameter end """ - StaticSize{T} + TypeParameter{T} -Static size of a manifold. +Represents numeric parameters of a manifold type as type parameters, allowing for static +specialization of methods. """ -struct StaticSize{T} <: AbstractManifoldSize end -StaticSize(t::NTuple) = StaticSize{t}() +struct TypeParameter{T} <: AbstractManifoldParameter end +TypeParameter(t::NTuple) = TypeParameter{t}() """ - RTSize{TS<:NTuple{N,Int} where N} + FieldParameter{TS<:NTuple{N,Int} where N} -Runtime size of a manifold. +Represents numeric parameters of a manifold type as values in a field, allowing for +less static specialization of methods and faster TTFX. """ -struct RTSize{TS<:NTuple{N,Int} where {N}} <: AbstractManifoldSize - size::TS +struct FieldParameter{TS<:NTuple{N,Int} where {N}} <: AbstractManifoldParameter + parameter::TS end -getsize(::StaticSize{T}) where {T} = T -getsize(S::RTSize) = S.size +get_parameter(::TypeParameter{T}) where {T} = T +get_parameter(P::FieldParameter) = P.parameter diff --git a/test/default_manifold.jl b/test/default_manifold.jl index dcc1fabc..9cc45a89 100644 --- a/test/default_manifold.jl +++ b/test/default_manifold.jl @@ -871,9 +871,10 @@ Base.size(x::MatrixVectorTransport) = (size(x.m, 2),) @test copy(M, p, X) === X end - @testset "static" begin - MS = ManifoldsBase.DefaultManifold(3; static = true) + @testset "static (size in type parameter)" begin + MS = ManifoldsBase.DefaultManifold(3; parameter = :type) @test (@inferred representation_size(MS)) == (3,) - @test repr(MS) == "DefaultManifold(3; field = ℝ, static = true)" + @test repr(MS) == "DefaultManifold(3; field = ℝ, parameter = :type)" + @test_throws ArgumentError ManifoldsBase.DefaultManifold(3; parameter = :foo) end end From 02acfd69b7664b7b0f29bff6f83a60e943b27d1e Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Tue, 4 Jul 2023 11:20:08 +0200 Subject: [PATCH 4/7] use`prod` --- src/DefaultManifold.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/DefaultManifold.jl b/src/DefaultManifold.jl index 9a0e09af..7432a79d 100644 --- a/src/DefaultManifold.jl +++ b/src/DefaultManifold.jl @@ -132,7 +132,7 @@ log!(::DefaultManifold, Y, p, q) = (Y .= q .- p) function manifold_dimension(M::DefaultManifold{𝔽}) where {𝔽} size = get_parameter(M.size) - return length(size) == 0 ? 1 : *(size...) * real_dimension(𝔽) + return prod(size) * real_dimension(𝔽) end number_system(::DefaultManifold{𝔽}) where {𝔽} = 𝔽 From 87a6f4a88e0c1c23f84994218f6ce1c9d0e2e361 Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Fri, 7 Jul 2023 16:40:56 +0200 Subject: [PATCH 5/7] simplify parameters --- src/DefaultManifold.jl | 6 +++--- src/maintypes.jl | 22 ++-------------------- 2 files changed, 5 insertions(+), 23 deletions(-) diff --git a/src/DefaultManifold.jl b/src/DefaultManifold.jl index 7432a79d..7dca4eea 100644 --- a/src/DefaultManifold.jl +++ b/src/DefaultManifold.jl @@ -21,12 +21,12 @@ Arguments: - `parameter`: whether a type parameter should be used to store `n`. By default size is stored in a field. Value can either be `:field` or `:type`. """ -struct DefaultManifold{𝔽,T<:AbstractManifoldParameter} <: AbstractManifold{𝔽} +struct DefaultManifold{𝔽,T} <: AbstractManifold{𝔽} size::T end function DefaultManifold(n::Vararg{Int}; field = ℝ, parameter::Symbol = :field) if parameter === :field - size = FieldParameter(n) + size = n elseif parameter === :type size = TypeParameter(n) else @@ -150,7 +150,7 @@ function Base.show(io::IO, M::DefaultManifold{𝔽,<:TypeParameter}) where {𝔽 "DefaultManifold($(join(get_parameter(M.size), ", ")); field = $(𝔽), parameter = :type)", ) end -function Base.show(io::IO, M::DefaultManifold{𝔽,<:FieldParameter}) where {𝔽} +function Base.show(io::IO, M::DefaultManifold{𝔽}) where {𝔽} return print(io, "DefaultManifold($(join(get_parameter(M.size), ", ")); field = $(𝔽))") end diff --git a/src/maintypes.jl b/src/maintypes.jl index f453843f..cd042547 100644 --- a/src/maintypes.jl +++ b/src/maintypes.jl @@ -37,32 +37,14 @@ matrix internally, it is possible to use [`@manifold_element_forwards`](@ref) an """ abstract type AbstractManifoldPoint end -""" - abstract type AbstractManifoldParameter end - -Abstract representation of numeric parameters for a manifold type. Can be either -[`TypeParameter`](@ref) or [`FieldParameter`](@ref). -""" -abstract type AbstractManifoldParameter end - """ TypeParameter{T} Represents numeric parameters of a manifold type as type parameters, allowing for static specialization of methods. """ -struct TypeParameter{T} <: AbstractManifoldParameter end +struct TypeParameter{T} end TypeParameter(t::NTuple) = TypeParameter{t}() -""" - FieldParameter{TS<:NTuple{N,Int} where N} - -Represents numeric parameters of a manifold type as values in a field, allowing for -less static specialization of methods and faster TTFX. -""" -struct FieldParameter{TS<:NTuple{N,Int} where {N}} <: AbstractManifoldParameter - parameter::TS -end - get_parameter(::TypeParameter{T}) where {T} = T -get_parameter(P::FieldParameter) = P.parameter +get_parameter(P) = P From 2646b01c28f6f12872cc4d5e4b31fc735e056a9f Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Thu, 3 Aug 2023 11:31:01 +0200 Subject: [PATCH 6/7] Extract wrapping to a separate function --- src/DefaultManifold.jl | 10 ++-------- src/maintypes.jl | 17 +++++++++++++++++ 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/src/DefaultManifold.jl b/src/DefaultManifold.jl index 7dca4eea..c015d9a0 100644 --- a/src/DefaultManifold.jl +++ b/src/DefaultManifold.jl @@ -7,7 +7,7 @@ to build one's own manifold. It is a simplified/shortened variant of `Euclidean` This manifold further illustrates how to type your manifold points and tangent vectors. Note that the interface does not require this, but it might be handy in debugging and educative -situations to verify correctness of involved variabes. +situations to verify correctness of involved variables. # Constructor @@ -25,13 +25,7 @@ struct DefaultManifold{𝔽,T} <: AbstractManifold{𝔽} size::T end function DefaultManifold(n::Vararg{Int}; field = ℝ, parameter::Symbol = :field) - if parameter === :field - size = n - elseif parameter === :type - size = TypeParameter(n) - else - throw(ArgumentError("Parameter can be either :field or :type. Given: $parameter")) - end + size = wrap_type_parameter(parameter, n) return DefaultManifold{field,typeof(size)}(size) end diff --git a/src/maintypes.jl b/src/maintypes.jl index cd042547..f34f0559 100644 --- a/src/maintypes.jl +++ b/src/maintypes.jl @@ -48,3 +48,20 @@ TypeParameter(t::NTuple) = TypeParameter{t}() get_parameter(::TypeParameter{T}) where {T} = T get_parameter(P) = P + +""" + wrap_type_parameter(parameter::Symbol, data) + +Wrap `data` in `TypeParameter` if `parameter` is `:type` or return `data` unchanged +if `parameter` is `:field`. Intended for use in manifold constructors, see +[`DefaultManifold`](@ref) for an example. +""" +@inline function wrap_type_parameter(parameter::Symbol, data) + if parameter === :field + return data + elseif parameter === :type + TypeParameter(data) + else + throw(ArgumentError("Parameter can be either :field or :type. Given: $parameter")) + end +end From 9d68f5b76cbd3269a084c1cdaaacf126074983b6 Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Thu, 3 Aug 2023 13:05:27 +0200 Subject: [PATCH 7/7] bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 5055318a..a956ea6a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ManifoldsBase" uuid = "3362f125-f0bb-47a3-aa74-596ffd7ef2fb" authors = ["Seth Axen ", "Mateusz Baran ", "Ronny Bergmann ", "Antoine Levitt "] -version = "0.14.8" +version = "0.14.9" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"