Skip to content

Commit

Permalink
adapt power manifold
Browse files Browse the repository at this point in the history
  • Loading branch information
mateuszbaran committed Oct 14, 2023
1 parent 3a1e4d6 commit 0b99b8e
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 11 deletions.
34 changes: 25 additions & 9 deletions src/manifolds/PowerManifold.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,13 @@ tangent space of the power manifold.
"""
struct PowerMetric <: AbstractMetric end

function PowerManifold(M::AbstractManifold{𝔽}, size::Integer...) where {𝔽}
return PowerManifold{𝔽,typeof(M),Tuple{size...},ArrayPowerRepresentation}(M)
function PowerManifold(
M::AbstractManifold{𝔽},
size::Integer...;
parameter::Symbol=:field,
) where {𝔽}
size_w = wrap_type_parameter(parameter, size)
return PowerManifold{𝔽,typeof(M),typeof(size_w),ArrayPowerRepresentation}(M, size_w)
end

"""
Expand Down Expand Up @@ -130,8 +135,9 @@ end
Return the manifold volume of an [`PowerManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/manifolds.html#ManifoldsBase.PowerManifold) `M`.
"""
function manifold_volume(M::PowerManifold{𝔽,<:AbstractManifold,TSize}) where {𝔽,TSize}
return manifold_volume(M.manifold)^prod(size_to_tuple(TSize))
function manifold_volume(M::PowerManifold)
size = get_parameter(M.size)
return manifold_volume(M.manifold)^prod(size)
end

function Random.rand(rng::AbstractRNG, d::PowerFVectorDistribution)
Expand Down Expand Up @@ -210,8 +216,8 @@ function Base.view(
return _write(M, rep_size, p, I...)
end

function representation_size(M::PowerManifold{𝔽,<:AbstractManifold,TSize}) where {𝔽,TSize}
return (representation_size(M.manifold)..., size_to_tuple(TSize)...)
function representation_size(M::PowerManifold)
return (representation_size(M.manifold)..., get_parameter(M.size)...)
end

@doc raw"""
Expand Down Expand Up @@ -266,9 +272,19 @@ end

function Base.show(
io::IO,
M::PowerManifold{𝔽,TM,TSize,ArrayPowerRepresentation},
) where {𝔽,TM,TSize}
return print(io, "PowerManifold($(M.manifold), $(join(TSize.parameters, ", ")))")
M::PowerManifold{𝔽,TM,TypeParameter{TSize},ArrayPowerRepresentation},
) where {𝔽,TM<:AbstractManifold{𝔽},TSize}
return print(
io,
"PowerManifold($(M.manifold), $(join(TSize.parameters, ", ")), parameter=:type)",
)
end
function Base.show(
io::IO,
M::PowerManifold{𝔽,TM,<:Tuple,ArrayPowerRepresentation},
) where {𝔽,TM<:AbstractManifold{𝔽}}
size = get_parameter(M.size)
return print(io, "PowerManifold($(M.manifold), $(join(size, ", ")))")
end

Distributions.support(tvd::PowerFVectorDistribution) = FVectorSupport(tvd.type)
Expand Down
8 changes: 6 additions & 2 deletions src/manifolds/VectorFiber.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@


"""
TensorProductType(spaces::VectorSpaceType...)
Expand All @@ -13,7 +12,12 @@ end
TensorProductType(spaces::VectorSpaceType...) = TensorProductType{typeof(spaces)}(spaces)

function inner(B::CotangentSpace, p, X, Y)
return inner(B.manifold, B.point, sharp(B.manifold, B.point, X), sharp(B.manifold, B.point, Y))
return inner(
B.manifold,
B.point,
sharp(B.manifold, B.point, X),
sharp(B.manifold, B.point, Y),
)
end

function Base.show(io::IO, tpt::TensorProductType)
Expand Down
6 changes: 6 additions & 0 deletions test/manifolds/power_manifold.jl
Original file line number Diff line number Diff line change
Expand Up @@ -489,4 +489,10 @@ end
X = repeat([0.0, 1.0, 0.0], 1, 5)
@test volume_density(Ms1, p, X) volume_density(Ms, p[:, 1], X[:, 1])^5
end

@testset "Static type parameter" begin
Ms1s = PowerManifold(Ms, 5; parameter=:type)
@test sprint(show, "text/plain", Ms1s) ==
"PowerManifold(Sphere(2, ℝ), 5, parameter=:type)"
end
end

0 comments on commit 0b99b8e

Please sign in to comment.