Skip to content

Commit

Permalink
some fixes for ManifoldsBase 0.15.6
Browse files Browse the repository at this point in the history
  • Loading branch information
mateuszbaran committed Dec 16, 2023
1 parent e7861d9 commit b269660
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 98 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ HybridArrays = "0.4"
Kronecker = "0.4, 0.5"
LinearAlgebra = "1.6"
ManifoldDiff = "0.3.7"
ManifoldsBase = "0.15.0"
ManifoldsBase = "0.15.6"
Markdown = "1.6"
MatrixEquations = "2.2"
OrdinaryDiffEq = "6.31"
Expand Down
8 changes: 8 additions & 0 deletions src/Manifolds.jl
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ using ManifoldsBase:
ℝ,
ℂ,
ℍ,
AbstractApproximationMethod,
AbstractBasis,
AbstractDecoratorManifold,
AbstractInverseRetractionMethod,
Expand Down Expand Up @@ -225,20 +226,26 @@ using ManifoldsBase:
CotangentSpaceType,
CoTFVector,
CoTVector,
CyclicProximalPointEstimation,
DefaultBasis,
DefaultOrthogonalBasis,
DefaultOrthonormalBasis,
DefaultOrDiagonalizingBasis,
DiagonalizingBasisData,
DiagonalizingOrthonormalBasis,
DifferentiatedRetractionVectorTransport,
EfficientEstimator,
EmbeddedManifold,
EmptyTrait,
EuclideanMetric,
ExponentialRetraction,
ExtrinsicEstimation,
Fiber,
FiberType,
FVector,
GeodesicInterpolation,
GeodesicInterpolationWithinRadius,
GradientDescentEstimation,
InverseProductRetraction,
IsIsometricEmbeddedManifold,
IsEmbeddedManifold,
Expand Down Expand Up @@ -295,6 +302,7 @@ using ManifoldsBase:
VectorSpaceFiber,
VectorSpaceType,
VeeOrthogonalBasis,
WeiszfeldEstimation,
@invoke_maker,
_euclidean_basis_vector,
combine_allocation_promotion_functions,
Expand Down
99 changes: 2 additions & 97 deletions src/statistics.jl
Original file line number Diff line number Diff line change
@@ -1,107 +1,12 @@
"""
AbstractEstimationMethod
Abstract type for defining statistical estimation methods.
Deprecated alias for `AbstractApproximationMethod`
"""
abstract type AbstractEstimationMethod end

"""
GradientDescentEstimation <: AbstractEstimationMethod
Method for estimation using gradient descent.
"""
struct GradientDescentEstimation <: AbstractEstimationMethod end

"""
CyclicProximalPointEstimation <: AbstractEstimationMethod
Method for estimation using the cyclic proximal point technique.
"""
struct CyclicProximalPointEstimation <: AbstractEstimationMethod end

"""
ExtrinsicEstimation <: AbstractEstimationMethod
Method for estimation in the ambient space and projecting to the manifold.
For [`mean`](@ref) estimation, [`GeodesicInterpolation`](@ref) is used for mean estimation
in the ambient space.
"""
struct ExtrinsicEstimation <: AbstractEstimationMethod end

"""
WeiszfeldEstimation <: AbstractEstimationMethod
Method for estimation using the Weiszfeld algorithm for the [`median`](@ref)
"""
struct WeiszfeldEstimation <: AbstractEstimationMethod end
const AbstractEstimationMethod = AbstractApproximationMethod

_unit_weights(n::Int) = StatsBase.UnitWeights{Float64}(n)

@doc raw"""
GeodesicInterpolation <: AbstractEstimationMethod
Repeated weighted geodesic interpolation method for estimating the Riemannian
center of mass.
The algorithm proceeds with the following simple online update:
```math
\begin{aligned}
μ_1 &= x_1\\
t_k &= \frac{w_k}{\sum_{i=1}^k w_i}\\
μ_{k} &= γ_{μ_{k-1}}(x_k; t_k),
\end{aligned}
```
where $x_k$ are points, $w_k$ are weights, $μ_k$ is the $k$th estimate of the
mean, and $γ_x(y; t)$ is the point at time $t$ along the
[`shortest_geodesic`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/functions.html#ManifoldsBase.shortest_geodesic-Tuple{AbstractManifold,%20Any,%20Any})
between points $x,y ∈ \mathcal M$. The algorithm
terminates when all $x_k$ have been considered. In the [`Euclidean`](@ref) case,
this exactly computes the weighted mean.
The algorithm has been shown to converge asymptotically with the sample size for
the following manifolds equipped with their default metrics when all sampled
points are in an open geodesic ball about the mean with corresponding radius
(see [`GeodesicInterpolationWithinRadius`](@ref)):
* All simply connected complete Riemannian manifolds with non-positive sectional
curvature at radius $∞$ [ChengHoSalehianVemuri:2016](@cite), in particular:
+ [`Euclidean`](@ref)
+ [`SymmetricPositiveDefinite`](@ref) [HoChengSalehianVemuri:2013](@cite)
* Other manifolds:
+ [`Sphere`](@ref): $\frac{π}{2}$ [SalehianEtAl:2015](@cite)
+ [`Grassmann`](@ref): $\frac{π}{4}$ [ChakrabortyVemuri:2015](@cite)
+ [`Stiefel`](@ref)/[`Rotations`](@ref): $\frac{π}{2 \sqrt 2}$ [ChakrabortyVemuri:2019](@cite)
For online variance computation, the algorithm additionally uses an analogous
recursion to the weighted Welford algorithm [West:1979](@cite).
"""
struct GeodesicInterpolation <: AbstractEstimationMethod end

"""
GeodesicInterpolationWithinRadius{T} <: AbstractEstimationMethod
Estimation of Riemannian center of mass using [`GeodesicInterpolation`](@ref)
with fallback to [`GradientDescentEstimation`](@ref) if any points are outside of a
geodesic ball of specified `radius` around the mean.
# Constructor
GeodesicInterpolationWithinRadius(radius)
"""
struct GeodesicInterpolationWithinRadius{T} <: AbstractEstimationMethod
radius::T

function GeodesicInterpolationWithinRadius(radius::T) where {T}
radius > 0 && return new{T}(radius)
return throw(
DomainError("The radius must be strictly postive, received $(radius)."),
)
end
end

function Base.show(io::IO, method::GeodesicInterpolationWithinRadius)
return print(io, "GeodesicInterpolationWithinRadius($(method.radius))")
end
Expand Down

0 comments on commit b269660

Please sign in to comment.