Support and test in-place sampling #176

Merged · 5 commits · Jun 23, 2021
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "AbstractGPs"
uuid = "99985d1d-32ba-4be9-9821-2ec096f28918"
authors = ["JuliaGaussianProcesses Team"]
version = "0.3.5"
version = "0.3.6"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3 changes: 2 additions & 1 deletion docs/src/api.md
@@ -64,15 +64,16 @@ If you are building something on top of AbstractGPs, try to implement it in terms of

```@docs
rand
rand!
Review thread on lines 66 to +67:
Member: Should we actually replace this by `Distributions._rand!` instead?

Member Author: I guess this would be reasonable if we want to follow the conventions in Distributions. I am just not sure we want to do this; it feels weird for an internal method such as `_rand!` to be part of the API (of course, this is a problem of Distributions, but since we already define `rand` separately, maybe we do not want to follow it here either).

@theogf (Member, Jun 22, 2021):

> if we want to follow the conventions in Distributions

That seems like something we want to do, no? Maybe we can add a reference to their API docs for better understanding?

Member Author: But if we really wanted to follow their conventions, we would not have to implement `rand` at all. However, I really think we should not do this, since it is quite likely to lead to incorrect types (I checked, and tests fail because the output type is incorrect for Float32; of course, one could work around this by defining `eltype(::FiniteGP)`, but I think the easier solution is to only perform in-place sampling when requested 🤷). So we do not follow the abstractions in Distributions even if we tell users to implement `Distributions._rand!`.

Member: The other reason to like having an explicit implementation for `rand` is AD, since our AD tools really don't like mutation.

marginals
logpdf(::AbstractGPs.FiniteGP, ::AbstractVector{<:Real})
posterior(::AbstractGPs.FiniteGP, ::AbstractVector{<:Real})
mean(::AbstractGPs.FiniteGP)
var(::AbstractGPs.FiniteGP)

```

#### Optional methods

Default implementations are provided for these, but you may wish to specialise for performance.
```@docs
mean_and_var(::AbstractGPs.FiniteGP)
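To make the element-type point from the review thread above concrete, here is a small illustrative sketch (not part of this PR; the setup and variable names are assumptions). The concern is that Distributions' generic `rand` fallback allocates its output via `eltype` of the distribution, which would not be `Float32` for a `FiniteGP` built from `Float32` inputs unless `eltype(::FiniteGP)` were defined; the explicit `rand` method avoids this, and `rand!` sidesteps it entirely by writing into caller-provided storage.

```julia
# Illustrative sketch of the Float32 point from the review thread (assumed setup,
# not part of the diff).
using AbstractGPs, KernelFunctions, Random

x32 = randn(Float32, 5)
fx32 = GP(SqExponentialKernel())(x32, 1f-6)   # FiniteGP with Float32 inputs and noise

# Explicit `rand`: the thread indicates the test suite expects Float32 output here,
# which a generic fallback through `eltype(::FiniteGP)` would not guarantee.
y = rand(MersenneTwister(1), fx32)
eltype(y)

# In-place `rand!`: samples are written into whatever storage (and element type)
# the caller supplies.
buf = Vector{Float32}(undef, length(x32))
rand!(MersenneTwister(1), fx32, buf)
```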
1 change: 1 addition & 0 deletions src/AbstractGPs.jl
@@ -14,6 +14,7 @@ using RecipesBase
using KernelFunctions: ColVecs, RowVecs

export GP,
rand!,
mean,
cov,
var,
37 changes: 37 additions & 0 deletions src/abstract_gp/finite_gp.jl
@@ -233,6 +233,43 @@ Random.rand(f::FiniteGP, N::Int) = rand(Random.GLOBAL_RNG, f, N)
Random.rand(rng::AbstractRNG, f::FiniteGP) = vec(rand(rng, f, 1))
Random.rand(f::FiniteGP) = vec(rand(f, 1))

# in-place sampling
"""
rand!(rng::AbstractRNG, f::FiniteGP, y::AbstractVecOrMat{<:Real})

Obtain sample(s) from the marginals `f` using `rng` and write them to `y`.

If `y` is a matrix, then each column corresponds to an independent sample.

```jldoctest
julia> f = GP(Matern32Kernel());

julia> x = randn(11);

julia> y = similar(x);

julia> rand!(f(x), y);

julia> rand!(MersenneTwister(123456), f(x), y);

julia> ys = similar(x, length(x), 3);

julia> rand!(f(x), ys);

julia> rand!(MersenneTwister(123456), f(x), ys);
```
"""
Random.rand!(::AbstractRNG, ::FiniteGP, ::AbstractVecOrMat{<:Real})

# Distributions defines methods for `rand!` (and `rand`) that fall back to `_rand!`
function Distributions._rand!(rng::AbstractRNG, f::FiniteGP, x::AbstractVecOrMat{<:Real})
m, C_mat = mean_and_cov(f)
C = cholesky(_symmetric(C_mat))
lmul!(C.U', randn!(rng, x))
x .+= m
return x
end

"""
logpdf(f::FiniteGP, y::AbstractVecOrMat{<:Real})

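For reference, the `_rand!` method above is the standard Cholesky-based sampler for a multivariate normal: with covariance `C = U'U` and `z ~ N(0, I)`, a draw is `m + U'z`. Below is a minimal standalone sketch of the same scheme (the helper name is made up; this is not the package code).

```julia
# Standalone sketch of the sampling scheme used by `_rand!` above
# (made-up helper name; not part of AbstractGPs).
using LinearAlgebra, Random

function sample_mvn!(rng::AbstractRNG, y::AbstractVecOrMat, m::AbstractVector, C::AbstractMatrix)
    chol = cholesky(Symmetric(C))  # C = U'U
    randn!(rng, y)                 # y <- z ~ N(0, I); each column is one draw
    lmul!(chol.U', y)              # y <- U'z, so that cov(y) = U'U = C
    y .+= m                        # y <- m + U'z
    return y
end

# Three correlated draws from N(m, C):
m = [0.0, 1.0]
C = [1.0 0.5; 0.5 2.0]
Y = sample_mvn!(MersenneTwister(1), zeros(2, 3), m, C)
```

Every step (`randn!`, `lmul!`, the broadcasted addition) reuses the caller's buffer, which is why the in-place path naturally inherits the element type of `y`.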
23 changes: 18 additions & 5 deletions test/abstract_gp/finite_gp.jl
@@ -48,21 +48,34 @@ end
    @test size(rand(rng, fx, 10)) == (length(x), 10)
    @test length(rand(fx)) == length(x)
    @test size(rand(fx, 10)) == (length(x), 10)

+    # Check that `rand!` calls do not error
+    y = similar(x)
+    rand!(rng, fx, y)
+    rand!(fx, y)
+    ys = similar(x, length(x), 10)
+    rand!(rng, fx, ys)
+    rand!(fx, ys)
end
@testset "rand (statistical)" begin
    rng = MersenneTwister(123456)
    N = 10
+    m0 = 1
    S = 100_000
    x = range(-3.0, 3.0; length=N)
-    f = FiniteGP(GP(1, SqExponentialKernel()), x, 1e-12)
+    f = FiniteGP(GP(m0, SqExponentialKernel()), x, 1e-12)

    # Check mean + covariance estimates approximately converge for single-GP sampling.
-    f̂ = rand(rng, f, S)
-    @test maximum(abs.(mean(f̂; dims=2) - mean(f))) < 1e-2
+    f̂1 = rand(rng, f, S)
+    f̂2 = similar(f̂1)
+    rand!(rng, f, f̂2)

+    for f̂ in (f̂1, f̂2)
+        @test maximum(abs.(mean(f̂; dims=2) - mean(f))) < 1e-2

-    Σ′ = (f̂ .- mean(f)) * (f̂ .- mean(f))' ./ S
-    @test mean(abs.(Σ′ - cov(f))) < 1e-2
+        Σ′ = (f̂ .- mean(f)) * (f̂ .- mean(f))' ./ S
+        @test mean(abs.(Σ′ - cov(f))) < 1e-2
+    end
end
# @testset "rand (gradients)" begin
# rng, N, S = MersenneTwister(123456), 10, 3
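As an aside on the tolerances in the statistical test above (not part of the PR): `SqExponentialKernel` has unit marginal variance, so the Monte Carlo standard error of each entry of the sample mean over `S` independent draws is roughly `sqrt(1/S)`.

```julia
# Back-of-the-envelope check of the 1e-2 tolerance used in the statistical test
# (an aside, not from the PR): with unit marginal variance, the standard error
# of the sample mean over S independent draws is sqrt(1/S).
S = 100_000
sqrt(1 / S)   # ≈ 0.00316, well below the 1e-2 tolerance
```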