-
Notifications
You must be signed in to change notification settings - Fork 56
/
GrassmannStiefel.jl
393 lines (318 loc) · 12.7 KB
/
GrassmannStiefel.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
#
# Default implementation for the matrix type, i.e. representing Grassmann points by
# (congruence classes of) Stiefel matrices
#
"""
    StiefelPoint <: AbstractManifoldPoint

A point on a [`Stiefel`](@ref) manifold, stored as a matrix in the field `value`.

This point type is mainly used for representing points on the [`Grassmann`](@ref) manifold,
where this is also the default representation; hence using it there is equivalent to using
plain `AbstractMatrix` values. It can also be used for points on the `Stiefel` manifold.
"""
struct StiefelPoint{T<:AbstractMatrix} <: AbstractManifoldPoint
    value::T
end
"""
    StiefelTVector <: TVector

A tangent vector on the [`Grassmann`](@ref) manifold, represented by a tangent vector from
the tangent space of a corresponding point on the [`Stiefel`](@ref) manifold
(see [`StiefelPoint`](@ref)), stored as a matrix in the field `value`.

This is the default representation, so it can be used interchangeably with plain
`AbstractMatrix` values.
"""
# This wraps a tangent vector, so it must subtype `TVector` (as documented above and as
# expected by `@manifold_vector_forwards`), not the point type `AbstractManifoldPoint`.
struct StiefelTVector{T<:AbstractMatrix} <: TVector
    value::T
end
# Forward the wrapper types to their wrapped matrix field `value`. Presumably these
# ManifoldsBase macros generate the point/vector interface for the wrappers and the
# unwrapping fallbacks for the listed manifolds (TODO confirm against ManifoldsBase docs).
ManifoldsBase.@manifold_element_forwards StiefelPoint value
ManifoldsBase.@manifold_vector_forwards StiefelTVector value
ManifoldsBase.@default_manifold_fallbacks Stiefel StiefelPoint StiefelTVector value value
# NOTE(review): separate real-field Stiefel variant, presumably to avoid dispatch
# ambiguities with real-specific Stiefel methods; confirm before removing.
ManifoldsBase.@default_manifold_fallbacks (Stiefel{<:Any,ℝ}) StiefelPoint StiefelTVector value value
ManifoldsBase.@default_manifold_fallbacks Grassmann StiefelPoint StiefelTVector value value
# On the Grassmann manifold, transport defaults to projection onto the target tangent
# space, both for plain matrix points and for `StiefelPoint`-wrapped points.
function default_vector_transport_method(::Grassmann, ::Type{<:AbstractArray})
    return ProjectionTransport()
end
function default_vector_transport_method(::Grassmann, ::Type{<:StiefelPoint})
    return ProjectionTransport()
end
@doc raw"""
    distance(M::Grassmann, p, q)

Compute the Riemannian distance on the [`Grassmann`](@ref) manifold `M`$= \mathrm{Gr}(n,k)$.
The distance is given by
````math
d_{\mathrm{Gr}(n,k)}(p,q) = \operatorname{norm}(\log_p(q)).
````
It is evaluated without forming the logarithm explicitly: with $σ_i$ the singular values of
$q(p^\mathrm{H}q)^{-1} - p$, the distance is $\bigl(\sum_i \operatorname{atan}(σ_i)^2\bigr)^{1/2}$,
which is the norm of $U\operatorname{atan}(Σ)V^\mathrm{H} = \log_p q$.
"""
function distance(::Grassmann, p, q)
    σs = svd(q / (p' * q) - p).S
    θs = atan.(σs)
    return norm(θs)
end
# Plain arrays on Grassmann are already matrices in the embedding, so embedding them is the
# identity (or a copy for the in-place variants).
embed(::Grassmann, p) = p
embed(::Grassmann, p, X) = X
embed!(::Grassmann, q, p) = copyto!(q, p)
embed!(::Grassmann, Y, p, X) = copyto!(Y, X)
# For the `StiefelPoint`/`StiefelTVector` wrappers, embedding unwraps (or copies out) the
# stored `value` matrix; the same four methods are defined for Grassmann and for Stiefel.
for MT in (:Grassmann, :Stiefel)
    @eval begin
        embed(::$MT, p::StiefelPoint) = p.value
        embed(::$MT, p::StiefelPoint, X::StiefelTVector) = X.value
        embed!(::$MT, q, p::StiefelPoint) = copyto!(q, p.value)
        embed!(::$MT, Y, p::StiefelPoint, X::StiefelTVector) = copyto!(Y, X.value)
    end
end
@doc raw"""
    exp(M::Grassmann, p, X)

Compute the exponential map on the [`Grassmann`](@ref) `M`$= \mathrm{Gr}(n,k)$ starting in
`p` with tangent vector (direction) `X`. Let $X = USV^\mathrm{H}$ denote the (thin) SVD of $X$.
Then the exponential map is written using
````math
z = p V\cos(S)V^\mathrm{H} + U\sin(S)V^\mathrm{H},
````
where $\cdot^{\mathrm{H}}$ denotes the complex conjugate transposed or Hermitian and the
cosine and sine are applied element wise to the diagonal entries of $S$. A final QR
decomposition $z=QR$ is performed for numerical stability reasons, yielding the result as
````math
\exp_p X = Q.
````
If `X` has zero norm, the result is `p` itself (copied into `q` in the in-place variant).
"""
exp(::Grassmann, ::Any...)
function exp!(M::Grassmann, q, p, X)
    if norm(M, p, X) ≈ 0
        # with default tolerances `x ≈ 0` only holds for an exact zero: skip the SVD for X = 0
        return copyto!(q, p)
    end
    F = svd(X)
    c = transpose(cos.(F.S))  # row vector, scales the columns of V
    s = transpose(sin.(F.S))  # row vector, scales the columns of U
    z = (p * (F.V .* c) + F.U .* s) * F.Vt
    # re-orthonormalize: the thin Q factor spans the same column space as z
    return copyto!(q, Array(qr(z).Q))
end
# Grassmann points are represented by Stiefel matrices, so the embedding is the Stiefel
# manifold of the same size and field; the size is type-stored in the first method and
# field-stored (`parameter=:field`) in the second, and the result mirrors that choice.
get_embedding(::Grassmann{TypeParameter{Tuple{n,k}},𝔽}) where {n,k,𝔽} = Stiefel(n, k, 𝔽)
function get_embedding(M::Grassmann{Tuple{Int,Int},𝔽}) where {𝔽}
    (rows, cols) = get_parameter(M.size)
    return Stiefel(rows, cols, 𝔽; parameter=:field)
end
@doc raw"""
    inner(M::Grassmann, p, X, Y)

Compute the inner product for two tangent vectors `X`, `Y` from the tangent space
of `p` on the [`Grassmann`](@ref) manifold `M`. The formula reads
````math
g_p(X,Y) = \operatorname{tr}(X^{\mathrm{H}}Y),
````
where $\cdot^{\mathrm{H}}$ denotes the complex conjugate transposed or Hermitian.
The base point `p` does not enter the computation.
"""
function inner(::Grassmann, p, X, Y)
    # `dot` on matrices is the Frobenius product Σ conj(Xᵢⱼ)Yᵢⱼ = tr(XᴴY)
    return dot(X, Y)
end
@doc raw"""
    inverse_retract(M::Grassmann, p, q, ::PolarInverseRetraction)

Compute the inverse retraction for the [`PolarRetraction`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/retractions.html#ManifoldsBase.PolarRetraction), on the
[`Grassmann`](@ref) manifold `M`, i.e.,
````math
\operatorname{retr}_p^{-1}q = q(p^\mathrm{H}q)^{-1} - p,
````
where $\cdot^{\mathrm{H}}$ denotes the complex conjugate transposed or Hermitian.
This requires $p^\mathrm{H}q$ to be invertible.
"""
inverse_retract(::Grassmann, ::Any, ::Any, ::PolarInverseRetraction)
function inverse_retract_polar!(::Grassmann, X, p, q)
    pᴴq = p' * q
    # right division solves against pᴴq instead of forming its inverse explicitly
    X .= q / pᴴq .- p
    return X
end
@doc raw"""
    inverse_retract(M, p, q, ::QRInverseRetraction)

Compute the inverse retraction for the [`QRRetraction`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/retractions.html#ManifoldsBase.QRRetraction), on the
[`Grassmann`](@ref) manifold `M`, i.e.,
````math
\operatorname{retr}_p^{-1}q = q(p^\mathrm{H}q)^{-1} - p,
````
where $\cdot^{\mathrm{H}}$ denotes the complex conjugate transposed or Hermitian.
On the Grassmann manifold this is the same formula as for the polar inverse retraction.
"""
inverse_retract(::Grassmann, ::Any, ::Any, ::QRInverseRetraction)
function inverse_retract_qr!(::Grassmann, X, p, q)
    scaled = q / (p' * q)  # q (pᴴq)⁻¹ via right division
    X .= scaled .- p
    return X
end
@doc raw"""
    log(M::Grassmann, p, q)

Compute the logarithmic map on the [`Grassmann`](@ref) `M`$ = \mathcal M=\mathrm{Gr}(n,k)$,
i.e. the tangent vector `X` whose corresponding [`geodesic`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/functions.html#ManifoldsBase.geodesic-Tuple{AbstractManifold,%20Any,%20Any}) starting from `p`
reaches `q` after time 1 on `M`. The formula reads
````math
\log_p q = U\operatorname{atan}(S)V^\mathrm{H},
````
where $\cdot^{\mathrm{H}}$ denotes the complex conjugate transposed or Hermitian.
Here $U$ and $V$ have orthonormal columns and $S$ is the diagonal matrix of singular
values of the (thin) SVD
````math
USV^\mathrm{H} = q(p^\mathrm{H}q)^{-1} - p,
````
i.e. of the polar inverse retraction of `q` at `p`, which requires $p^\mathrm{H}q$ to be
invertible. In this formula the $\operatorname{atan}$ is meant elementwise on the
diagonal of $S$.
"""
log(::Grassmann, ::Any...)
function log!(M::Grassmann, X, p, q)
    # X ← q (pᴴq)⁻¹ - p; its singular values are mapped through atan below
    inverse_retract_polar!(M, X, p, q)
    d = svd(X)
    # atan.(S) .* Vᴴ scales the rows of Vᴴ, i.e. computes Diagonal(atan.(S)) * Vᴴ
    mul!(X, d.U, atan.(d.S) .* d.Vt)
    return X
end
@doc raw"""
    project(M::Grassmann, p)

Project `p` from the embedding onto the [`Grassmann`](@ref) `M`, i.e. compute `q`
as the orthonormal factor of the polar decomposition of $p$, so that
$q^{\mathrm{H}}q$ is the identity. With the thin SVD $p = USV^{\mathrm{H}}$ this is
$q = UV^{\mathrm{H}}$, where $\cdot^{\mathrm{H}}$ denotes the Hermitian, i.e. complex
conjugate transposed.
"""
project(::Grassmann, ::Any)
function project!(::Grassmann, q, p)
    F = svd(p)
    # mul! writes U Vᴴ into q and returns q
    return mul!(q, F.U, F.Vt)
end
@doc raw"""
    project(M::Grassmann, p, X)

Project the `n`-by-`k` `X` onto the tangent space of `p` on the [`Grassmann`](@ref) `M`,
which is computed by
````math
\operatorname{proj_p}(X) = X - pp^{\mathrm{H}}X,
````
where $\cdot^{\mathrm{H}}$ denotes the complex conjugate transposed or Hermitian.
"""
project(::Grassmann, ::Any...)
function project!(::Grassmann, Y, p, X)
    copyto!(Y, X)
    C = p' * X            # coefficients of X along the columns of p
    mul!(Y, p, C, -1, 1)  # Y ← Y - p*C, in place
    return Y
end
@doc raw"""
    rand(M::Grassmann; σ::Real=1.0, vector_at=nothing)

When `vector_at` is `nothing`, return a random point `p` on the [`Grassmann`](@ref) manifold
`M`: a random (Gaussian) matrix with standard deviation `σ` of matching size is generated
and the orthonormal factor of its QR decomposition is returned.

When `vector_at` is not `nothing`, return a random tangent vector from the tangent space
``T_p\mathrm{Gr}(n,k)`` at `p = vector_at`: a Gaussian matrix with mean zero and standard
deviation `σ` is projected onto the tangent space and then normalized to unit (Frobenius)
norm, so for nonzero `σ` the scale `σ` cancels out.

In the in-place `rand!`, the default for `σ` is `one(real(eltype(pX)))`.
"""
rand(M::Grassmann; σ::Real=1.0)
function Random.rand!(
    rng::AbstractRNG,
    M::Grassmann{<:Any,𝔽},
    pX;
    σ::Real=one(real(eltype(pX))),
    vector_at=nothing,
) where {𝔽}
    if vector_at !== nothing
        # tangent vector: project a Gaussian ambient matrix, then rescale to unit norm
        G = σ * randn(rng, eltype(pX), size(pX))
        project!(M, pX, vector_at, G)
        pX ./= norm(pX)
        return pX
    end
    # point: orthonormalize a Gaussian n×k matrix via QR
    n, k = get_parameter(M.size)
    T = 𝔽 === ℝ ? Float64 : ComplexF64
    A = σ * randn(rng, T, (n, k))
    pX .= qr(A).Q[:, 1:k]
    return pX
end
@doc raw"""
    representation_size(M::Grassmann)

Return the representation size or matrix dimension of a point on the [`Grassmann`](@ref)
`M`, i.e. $(n,k)$ for both the real-valued and the complex-valued case.
"""
function representation_size(M::Grassmann)
    # `get_parameter` unwraps the stored size (n, k)
    return get_parameter(M.size)
end
@doc raw"""
    retract(M::Grassmann, p, X, ::PolarRetraction)

Compute the SVD-based retraction [`PolarRetraction`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/retractions.html#ManifoldsBase.PolarRetraction) on the
[`Grassmann`](@ref) `M`. With $USV^\mathrm{H} = p + X$ the retraction reads
````math
\operatorname{retr}_p X = UV^\mathrm{H},
````
where $\cdot^{\mathrm{H}}$ denotes the complex conjugate transposed or Hermitian.
The in-place variant with step size `t` uses $p + tX$ instead of $p + X$.
"""
retract(::Grassmann, ::Any, ::Any, ::PolarRetraction)
function retract_polar!(M::Grassmann, q, p, X, t::Number)
    @. q = p + t * X
    # the polar factor U Vᴴ of p + tX is exactly the projection onto the manifold
    project!(M, q, q)
    return q
end
@doc raw"""
    retract(M::Grassmann, p, X, ::QRRetraction)

Compute the QR-based retraction [`QRRetraction`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/retractions.html#ManifoldsBase.QRRetraction) on the
[`Grassmann`](@ref) `M`. With the thin QR decomposition $QR = p + X$ the retraction reads
````math
\operatorname{retr}_p X = QD,
````
where $D$ is the $k × k$ diagonal matrix
````math
D = \operatorname{diag}\left( \operatorname{sgn}\left(R_{ii}+\frac{1}{2}\right)_{i=1}^k \right),
````
which fixes the sign ambiguity of the QR decomposition. The shift by $\frac{1}{2}$ ensures
that a zero diagonal entry $R_{ii}$ yields a sign of $1$ instead of $0$.
The in-place variant with step size `t` uses $p + tX$ instead of $p + X$.
"""
retract(::Grassmann, ::Any, ::Any, ::QRRetraction)
function retract_qr!(::Grassmann, q, p, X, t::Number)
    q .= p .+ t .* X
    qrfac = qr(q)
    d = diag(qrfac.R)
    # row vector of signs, broadcast to scale the columns of the thin Q factor (Q * D)
    q .= Array(qrfac.Q) .* sign.(transpose(d) .+ 1 // 2)
    return q
end
@doc raw"""
    riemannian_Hessian(M::Grassmann, p, G, H, X)

The Riemannian Hessian can be computed by adopting Eq. (6.6) [Nguyen:2023](@cite),
where we use for the [`EuclideanMetric`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/manifolds.html#ManifoldsBase.EuclideanMetric) ``α_0=α_1=1`` in their formula.
Let ``\nabla f(p)`` denote the Euclidean gradient `G`,
``\nabla^2 f(p)[X]`` the Euclidean Hessian `H`. Then the formula reads
```math
\operatorname{Hess}f(p)[X]
=
\operatorname{proj}_{T_p\mathcal M}\Bigl(
∇^2f(p)[X] - X p^{\mathrm{H}}∇f(p)
\Bigr).
```
Compared to Eq. (5.6) also the metric conversion simplifies to the identity.
"""
riemannian_Hessian(M::Grassmann, p, G, H, X)
function riemannian_Hessian!(M::Grassmann, Y, p, G, H, X)
    # Euclidean Hessian minus the correction term X pᴴ ∇f(p) from the formula above ...
    W = H - X * p' * G
    # ... mapped onto the tangent space at p
    project!(M, Y, p, W)
    return Y
end
@doc raw"""
    riemann_tensor(::Grassmann{<:Any,ℝ}, p, X, Y, Z)

Compute the value of Riemann tensor on the real [`Grassmann`](@ref) manifold.
The formula reads [Rentmeesters:2011](@cite)
``R(X,Y)Z = (XY^\mathrm{T} - YX^\mathrm{T})Z + Z(Y^\mathrm{T}X - X^\mathrm{T}Y)``.
"""
riemann_tensor(::Grassmann{<:Any,ℝ}, p, X, Y, Z)
function riemann_tensor!(::Grassmann{<:Any,ℝ}, Xresult, p, X, Y, Z)
    XYᵀ = X * Y'
    YXᵀ = XYᵀ'
    YᵀX = Y' * X
    XᵀY = YᵀX'
    # The second term enters with a PLUS sign, as in the formula above. With a minus sign
    # ⟨R(X,Y)Y, X⟩ can become negative (e.g. on Gr(4,2)), which is impossible on the
    # nonnegatively curved Grassmann manifold.
    Xresult .= (XYᵀ - YXᵀ) * Z .+ Z * (YᵀX - XᵀY)
    return Xresult
end
# Grassmann prints as its constructor call (adding `parameter=:field` when the size is
# stored in a field); the wrapper types print their wrapped matrix.
function Base.show(io::IO, ::Grassmann{TypeParameter{Tuple{n,k}},𝔽}) where {n,k,𝔽}
    print(io, "Grassmann($(n), $(k), $(𝔽))")
    return nothing
end
function Base.show(io::IO, M::Grassmann{Tuple{Int,Int},𝔽}) where {𝔽}
    n, k = get_parameter(M.size)
    print(io, "Grassmann($(n), $(k), $(𝔽); parameter=:field)")
    return nothing
end
function Base.show(io::IO, p::StiefelPoint)
    print(io, "StiefelPoint($(p.value))")
    return nothing
end
function Base.show(io::IO, X::StiefelTVector)
    print(io, "StiefelTVector($(X.value))")
    return nothing
end
"""
    uniform_distribution(M::Grassmann{<:Any,ℝ}, p)

Uniform distribution on the given (real-valued) [`Grassmann`](@ref) manifold `M`,
i.e. the normalized Haar measure on `M`.
Samples are drawn from a standard matrix normal distribution of size `(n, k)` and mapped
onto `M` by taking the left singular vectors; generated points are of similar type as `p`.
The implementation is based on Section 2.5.1 in [Chikuse:2003](@cite);
see also Theorem 2.2.2(iii) in [Chikuse:2003](@cite).
"""
function uniform_distribution(M::Grassmann{<:Any,ℝ}, p)
    n, k = get_parameter(M.size)
    s = one(eltype(p))
    # zero mean, unit scalar row and column covariances: i.i.d. standard normal entries
    mn = MatrixNormal(
        Distributions.Zeros(n, k),
        Distributions.PDMats.ScalMat(n, s),
        Distributions.PDMats.ScalMat(k, s),
    )
    # map a Gaussian sample x onto the manifold via its left singular vectors
    return ProjectedPointDistribution(M, mn, (_, q, x) -> (q .= svd(x).U), p)
end
# Docstring-only declaration: no method body is defined on this line.
@doc raw"""
    vector_transport_to(M::Grassmann, p, X, q, ::ProjectionTransport)

Compute the projection based transport on the [`Grassmann`](@ref) `M` by
interpreting `X` from the tangent space at `p` as a point in the embedding and
projecting it onto the tangent space at `q`.
"""
vector_transport_to(::Grassmann, ::Any, ::Any, ::Any, ::ProjectionTransport)
@doc raw"""
    zero_vector(M::Grassmann, p)

Return the zero tangent vector from the tangent space at `p` on the [`Grassmann`](@ref) `M`,
which is given by a zero matrix of the same size as `p`.
"""
zero_vector(::Grassmann, ::Any...)
function zero_vector!(::Grassmann, X, p)
    # overwrite X with zeros in place; `fill!` returns X
    return fill!(X, 0)
end