-
Notifications
You must be signed in to change notification settings - Fork 56
/
projected_distribution.jl
129 lines (113 loc) · 3.49 KB
/
projected_distribution.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
"""
ProjectedPointDistribution(M::AbstractManifold, d, proj!, p)
Generates a random point in ambient space of `M` and projects it to `M`
using function `proj!`. Generated arrays are of type `TResult`, which can be
specified by providing the `p` argument.
"""
struct ProjectedPointDistribution{TResult,TM<:AbstractManifold,TD<:Distribution,TProj} <:
MPointDistribution{TM}
manifold::TM
distribution::TD
proj!::TProj
end
# Outer constructor: the last argument contributes only its type, which becomes the
# `TResult` parameter that sampled arrays are converted to.
function ProjectedPointDistribution(
    M::AbstractManifold,
    d::Distribution,
    proj!,
    ::TResult,
) where {TResult}
    TM, TD, TP = typeof(M), typeof(d), typeof(proj!)
    return ProjectedPointDistribution{TResult,TM,TD,TP}(M, d, proj!)
end
"""
projected_distribution(M::AbstractManifold, d, [p=rand(d)])
Wrap the standard distribution `d` into a manifold-valued distribution. Generated
points will be of similar type to `p`. By default, the type is not changed.
"""
function projected_distribution(M::AbstractManifold, d, p=rand(d))
return ProjectedPointDistribution(M, d, project!, p)
end
# Draw a single point: sample the ambient distribution, convert the sample to the
# array type `T`, project it with `proj!` and return whatever `proj!` returns.
function Random.rand(rng::AbstractRNG, d::ProjectedPointDistribution{T}) where {T}
    q = convert(T, rand(rng, d.distribution))
    M = d.manifold
    return d.proj!(M, q, q)
end
# Draw `n` independent points from `d`.
#
# All `n` ambient samples are drawn first and converted to `TResult`. Each one is
# then projected with `d.proj!(d.manifold, q, q)`. The value returned by `proj!` is
# stored back into the vector, mirroring the single-sample method. This way the
# projection is kept even when `proj!` does not (or cannot) mutate its argument in
# place. The result stays a `Vector{TResult}`, also for `n == 0`.
function Random.rand(
    rng::AbstractRNG,
    d::ProjectedPointDistribution{TResult},
    n::Int,
) where {TResult}
    ps = [convert(TResult, rand(rng, d.distribution)) for _ in 1:n]
    for i in eachindex(ps)
        ps[i] = d.proj!(d.manifold, ps[i], ps[i])
    end
    return ps
end
# Overload of Distributions' internal in-place sampler: fill `q` with an ambient
# sample from `d.distribution`, then project it onto the manifold with `proj!`
# and return the result of that projection call.
function Distributions._rand!(
    rng::AbstractRNG,
    d::ProjectedPointDistribution,
    q::AbstractArray{<:Number},
)
    Distributions._rand!(rng, d.distribution, q)
    M = d.manifold
    return d.proj!(M, q, q)
end
# Support is described by `MPointSupport` over the distribution's manifold.
function Distributions.support(d::ProjectedPointDistribution)
    return MPointSupport(d.manifold)
end
"""
ProjectedFVectorDistribution(type::VectorSpaceFiber, p, d, project!)
Generates a random vector from ambient space of manifold `type.manifold`
at point `p` and projects it to vector space of type `type` using function
`project!`, see [`project`](@ref) for documentation.
Generated arrays are of type `TResult`.
"""
struct ProjectedFVectorDistribution{
TResult,
TSpace<:VectorSpaceFiber,
TD<:Distribution,
TProj,
} <: FVectorDistribution{TSpace}
type::TSpace
distribution::TD
project!::TProj
end
# Outer constructor: the trailing argument is used only for its type, which becomes
# `TResult`, the array type that sampled vectors are converted to.
function ProjectedFVectorDistribution(
    type::VectorSpaceFiber,
    d::Distribution,
    project!,
    ::TResult,
) where {TResult}
    TF, TD, TP = typeof(type), typeof(d), typeof(project!)
    return ProjectedFVectorDistribution{TResult,TF,TD,TP}(type, d, project!)
end
# Draw a single vector:
# 1. sample the ambient distribution;
# 2. reshape the sample to the shape of the base point `d.type.point`;
# 3. convert it to `T`;
# 4. project it onto the fiber with `d.project!(manifold, X, point, X)`.
function Random.rand(rng::AbstractRNG, d::ProjectedFVectorDistribution{T}) where {T}
    fiber = d.type
    raw = rand(rng, d.distribution)
    X = convert(T, reshape(raw, size(fiber.point)))
    return d.project!(fiber.manifold, X, fiber.point, X)
end
# Overload of Distributions' internal in-place sampler for fiber vectors.
#
# Draws a fresh vector with `rand(rng, d)` and copies it into `X`, returning `X`.
# Calling `Distributions._rand!(rng, d.distribution, X)` directly does not work
# for all array types, hence the allocate-and-copy approach.
function Distributions._rand!(
    rng::AbstractRNG,
    d::ProjectedFVectorDistribution,
    X::AbstractArray{<:Number},
)
    return copyto!(X, rand(rng, d))
end
"""
normal_tvector_distribution(M::Euclidean, p, σ)
Normal distribution in ambient space with standard deviation `σ`
projected to tangent space at `p`.
"""
function normal_tvector_distribution(M::AbstractManifold, p, σ)
d = Distributions.MvNormal(zero(vec(p)), σ * I)
return ProjectedFVectorDistribution(TangentSpace(M, p), d, project!, p)
end
# Support is described by `FVectorSupport` over the distribution's fiber.
Distributions.support(tvd::ProjectedFVectorDistribution) = FVectorSupport(tvd.type)