Skip to content

Commit

Permalink
fix allocation on tangent spaces and fixed rank for random point/vect…
Browse files Browse the repository at this point in the history
…or generation.
  • Loading branch information
kellertuer committed Sep 22, 2023
1 parent b5cdb3c commit 0ed2117
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 26 deletions.
38 changes: 12 additions & 26 deletions src/manifolds/FixedRankMatrices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,17 @@ function allocate_result(
# vals are p and X, so we can use their fields to set up those of the UMVTVector
return UMVTVector(allocate(p.U, m, k), allocate(p.S, k, k), allocate(p.Vt, k, n))
end
function allocate_result(::FixedRankMatrices{m,n,k}, ::typeof(rand), p) where {m,n,k}

Check warning on line 130 in src/manifolds/FixedRankMatrices.jl

View check run for this annotation

Codecov / codecov/patch

src/manifolds/FixedRankMatrices.jl#L130

Added line #L130 was not covered by tests
# vals are p and X, so we can use their fields to set up those of the UMVTVector
return UMVTVector(allocate(p.U, m, k), allocate(p.S, k, k), allocate(p.Vt, k, n))

Check warning on line 132 in src/manifolds/FixedRankMatrices.jl

View check run for this annotation

Codecov / codecov/patch

src/manifolds/FixedRankMatrices.jl#L132

Added line #L132 was not covered by tests
end
function allocate_result(::FixedRankMatrices{m,n,k}, ::typeof(rand)) where {m,n,k}
return SVDMPoint(

Check warning on line 135 in src/manifolds/FixedRankMatrices.jl

View check run for this annotation

Codecov / codecov/patch

src/manifolds/FixedRankMatrices.jl#L134-L135

Added lines #L134 - L135 were not covered by tests
Matrix{Float64}(undef, m, k),
Vector{Float64}(undef, k),
Matrix{Float64}(undef, k, n),
)
end

Base.copy(v::UMVTVector) = UMVTVector(copy(v.U), copy(v.M), copy(v.Vt))

Expand Down Expand Up @@ -448,32 +459,7 @@ and the singular values are sampled uniformly at random.
If `vector_at` is not `nothing`, generate a random tangent vector in the tangent space of
the point `vector_at` on the `FixedRankMatrices` manifold `M`.
"""
function Random.rand(M::FixedRankMatrices; vector_at=nothing, kwargs...)
return rand(Random.default_rng(), M; vector_at=vector_at, kwargs...)
end
function Random.rand(
rng::AbstractRNG,
M::FixedRankMatrices{m,n,k};
vector_at=nothing,
kwargs...,
) where {m,n,k}
if vector_at === nothing
p = SVDMPoint(
Matrix{Float64}(undef, m, k),
Vector{Float64}(undef, k),
Matrix{Float64}(undef, k, n),
)
return rand!(rng, M, p; kwargs...)
else
X = UMVTVector(
Matrix{Float64}(undef, m, k),
Matrix{Float64}(undef, k, k),
Matrix{Float64}(undef, k, n),
)
return rand!(rng, M, X; vector_at, kwargs...)
end
end

Random.rand(M::FixedRankMatrices; vector_at=nothing, kwargs...)
function Random.rand!(
rng::AbstractRNG,
::FixedRankMatrices{m,n,k},
Expand Down
4 changes: 4 additions & 0 deletions src/manifolds/VectorBundle.jl
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,10 @@ See also [`VectorBundleInverseProductRetraction`](@ref).
"""
struct VectorBundleProductRetraction <: AbstractRetractionMethod end

function allocate_result(M::TangentSpaceAtPoint, ::typeof(rand))
return zero_vector(M.fiber.manifold, M.point)

Check warning on line 296 in src/manifolds/VectorBundle.jl

View check run for this annotation

Codecov / codecov/patch

src/manifolds/VectorBundle.jl#L295-L296

Added lines #L295 - L296 were not covered by tests
end

base_manifold(B::VectorBundleFibers) = base_manifold(B.manifold)
base_manifold(B::VectorSpaceAtPoint) = base_manifold(B.fiber)
base_manifold(B::VectorBundle) = base_manifold(B.manifold)
Expand Down

0 comments on commit 0ed2117

Please sign in to comment.