diff --git a/Project.toml b/Project.toml index 0bd91a22b8..51e6012a86 100644 --- a/Project.toml +++ b/Project.toml @@ -51,7 +51,7 @@ HybridArrays = "0.4" Kronecker = "0.4, 0.5" LinearAlgebra = "1.6" ManifoldDiff = "0.3.7" -ManifoldsBase = "0.15.0" +ManifoldsBase = "0.15.6" Markdown = "1.6" MatrixEquations = "2.2" OrdinaryDiffEq = "6.31" diff --git a/src/Manifolds.jl b/src/Manifolds.jl index 3ca4e4c428..7732cbf164 100644 --- a/src/Manifolds.jl +++ b/src/Manifolds.jl @@ -198,6 +198,7 @@ using ManifoldsBase: ℝ, ℂ, ℍ, + AbstractApproximationMethod, AbstractBasis, AbstractDecoratorManifold, AbstractInverseRetractionMethod, @@ -225,6 +226,7 @@ using ManifoldsBase: CotangentSpaceType, CoTFVector, CoTVector, + CyclicProximalPointEstimation, DefaultBasis, DefaultOrthogonalBasis, DefaultOrthonormalBasis, @@ -232,13 +234,18 @@ using ManifoldsBase: DiagonalizingBasisData, DiagonalizingOrthonormalBasis, DifferentiatedRetractionVectorTransport, + EfficientEstimator, EmbeddedManifold, EmptyTrait, EuclideanMetric, ExponentialRetraction, + ExtrinsicEstimation, Fiber, FiberType, FVector, + GeodesicInterpolation, + GeodesicInterpolationWithinRadius, + GradientDescentEstimation, InverseProductRetraction, IsIsometricEmbeddedManifold, IsEmbeddedManifold, @@ -295,6 +302,7 @@ using ManifoldsBase: VectorSpaceFiber, VectorSpaceType, VeeOrthogonalBasis, + WeiszfeldEstimation, @invoke_maker, _euclidean_basis_vector, combine_allocation_promotion_functions, diff --git a/src/statistics.jl b/src/statistics.jl index 7a1030e60f..6abd1ad754 100644 --- a/src/statistics.jl +++ b/src/statistics.jl @@ -1,107 +1,12 @@ """ AbstractEstimationMethod -Abstract type for defining statistical estimation methods. +Deprecated alias for `AbstractApproximationMethod` """ -abstract type AbstractEstimationMethod end - -""" - GradientDescentEstimation <: AbstractEstimationMethod - -Method for estimation using gradient descent. -""" -struct GradientDescentEstimation <: AbstractEstimationMethod end - -""" - CyclicProximalPointEstimation <: AbstractEstimationMethod - -Method for estimation using the cyclic proximal point technique. -""" -struct CyclicProximalPointEstimation <: AbstractEstimationMethod end - -""" - ExtrinsicEstimation <: AbstractEstimationMethod - -Method for estimation in the ambient space and projecting to the manifold. - -For [`mean`](@ref) estimation, [`GeodesicInterpolation`](@ref) is used for mean estimation -in the ambient space. -""" -struct ExtrinsicEstimation <: AbstractEstimationMethod end - -""" - WeiszfeldEstimation <: AbstractEstimationMethod - -Method for estimation using the Weiszfeld algorithm for the [`median`](@ref) -""" -struct WeiszfeldEstimation <: AbstractEstimationMethod end +const AbstractEstimationMethod = AbstractApproximationMethod _unit_weights(n::Int) = StatsBase.UnitWeights{Float64}(n) -@doc raw""" - GeodesicInterpolation <: AbstractEstimationMethod - -Repeated weighted geodesic interpolation method for estimating the Riemannian -center of mass. - -The algorithm proceeds with the following simple online update: - -```math -\begin{aligned} -μ_1 &= x_1\\ -t_k &= \frac{w_k}{\sum_{i=1}^k w_i}\\ -μ_{k} &= γ_{μ_{k-1}}(x_k; t_k), -\end{aligned} -``` - -where $x_k$ are points, $w_k$ are weights, $μ_k$ is the $k$th estimate of the -mean, and $γ_x(y; t)$ is the point at time $t$ along the -[`shortest_geodesic`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/functions.html#ManifoldsBase.shortest_geodesic-Tuple{AbstractManifold,%20Any,%20Any}) -between points $x,y ∈ \mathcal M$. The algorithm -terminates when all $x_k$ have been considered. In the [`Euclidean`](@ref) case, -this exactly computes the weighted mean. - -The algorithm has been shown to converge asymptotically with the sample size for -the following manifolds equipped with their default metrics when all sampled -points are in an open geodesic ball about the mean with corresponding radius -(see [`GeodesicInterpolationWithinRadius`](@ref)): - -* All simply connected complete Riemannian manifolds with non-positive sectional - curvature at radius $∞$ [ChengHoSalehianVemuri:2016](@cite), in particular: - + [`Euclidean`](@ref) - + [`SymmetricPositiveDefinite`](@ref) [HoChengSalehianVemuri:2013](@cite) -* Other manifolds: - + [`Sphere`](@ref): $\frac{π}{2}$ [SalehianEtAl:2015](@cite) - + [`Grassmann`](@ref): $\frac{π}{4}$ [ChakrabortyVemuri:2015](@cite) - + [`Stiefel`](@ref)/[`Rotations`](@ref): $\frac{π}{2 \sqrt 2}$ [ChakrabortyVemuri:2019](@cite) - -For online variance computation, the algorithm additionally uses an analogous -recursion to the weighted Welford algorithm [West:1979](@cite). -""" -struct GeodesicInterpolation <: AbstractEstimationMethod end - -""" - GeodesicInterpolationWithinRadius{T} <: AbstractEstimationMethod - -Estimation of Riemannian center of mass using [`GeodesicInterpolation`](@ref) -with fallback to [`GradientDescentEstimation`](@ref) if any points are outside of a -geodesic ball of specified `radius` around the mean. - -# Constructor - - GeodesicInterpolationWithinRadius(radius) -""" -struct GeodesicInterpolationWithinRadius{T} <: AbstractEstimationMethod - radius::T - - function GeodesicInterpolationWithinRadius(radius::T) where {T} - radius > 0 && return new{T}(radius) - return throw( - DomainError("The radius must be strictly postive, received $(radius)."), - ) - end -end - function Base.show(io::IO, method::GeodesicInterpolationWithinRadius) return print(io, "GeodesicInterpolationWithinRadius($(method.radius))") end