Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Base implementation of SVGP #9

Merged
merged 69 commits into from
Jul 30, 2021
Merged
Show file tree
Hide file tree
Changes from 66 commits
Commits
Show all changes
69 commits
Select commit Hold shift + click to select a range
3cec6a0
Add files
rossviljoen Jun 19, 2021
6c814a6
Fixed KL and posterior covariance.
rossviljoen Jun 20, 2021
798f77a
Update example to use Flux
rossviljoen Jun 20, 2021
0641423
Remove Flux as a dep & factor out expected_loglik
rossviljoen Jun 21, 2021
1e4fc90
Update example to use basic Flux layer
rossviljoen Jun 21, 2021
bb42044
Add minibatching.
rossviljoen Jun 22, 2021
102d812
Improved variance calculation.
rossviljoen Jun 23, 2021
3089d93
Initial quadrature implementation
rossviljoen Jun 24, 2021
59474c5
Moved quadrature to new file.
rossviljoen Jul 2, 2021
25e6627
Fixed AD for quadrature.
rossviljoen Jul 3, 2021
54b5470
Fixed AD for KL divergence.
rossviljoen Jul 3, 2021
5e1c882
Added classification example.
rossviljoen Jul 4, 2021
ce20eba
Updated examples.
rossviljoen Jul 4, 2021
359b3d5
Renamed SVGPLayer to SVGPModel.
rossviljoen Jul 4, 2021
3bdbedb
Added basic test structure.
rossviljoen Jul 4, 2021
cb3a341
Started equivalence tests
rossviljoen Jul 4, 2021
3a2c8a9
First pass (doesn't work yet)
rossviljoen Jul 4, 2021
005f8f0
Working tests
rossviljoen Jul 6, 2021
443a2d4
Fixed KL divergence
rossviljoen Jul 6, 2021
92da73c
Refactored elbo stuff
rossviljoen Jul 6, 2021
7d05d1b
Fixed elbo mistakes
rossviljoen Jul 7, 2021
c0dd737
Remove type restiction in ELBO
rossviljoen Jul 7, 2021
92dcdf5
Infer batch size
rossviljoen Jul 7, 2021
f8086c8
Merge branch 'master' into ross/tests
rossviljoen Jul 7, 2021
787c57d
Merge branch 'dev' into base_implementation
rossviljoen Jul 10, 2021
ec5fa05
Added docstrings to elbo.jl
rossviljoen Jul 13, 2021
6d4e87b
Added cross-covariance
rossviljoen Jul 13, 2021
22c999a
Removed unnecessary dependencies
rossviljoen Jul 13, 2021
2763972
Updated regression example
rossviljoen Jul 13, 2021
23e5c2e
Added exact posterior tests
rossviljoen Jul 14, 2021
60d5072
Merge pull request #6 from rossviljoen/ross/tests
rossviljoen Jul 14, 2021
a8e5cbe
Address review comments
rossviljoen Jul 14, 2021
1bbeae0
Fix docstrings
rossviljoen Jul 16, 2021
1a0782f
Rename kldivergence
rossviljoen Jul 18, 2021
eddc7ab
Factor out exact posterior
rossviljoen Jul 19, 2021
7ea3c2f
Use AbstractGPs TestUtils
rossviljoen Jul 19, 2021
9b6557f
Added support for prior mean function
rossviljoen Jul 19, 2021
0e59e49
Added MC expectation and refactored elbo
rossviljoen Jul 21, 2021
38ed15f
Updated docstrings
rossviljoen Jul 21, 2021
c8a974f
Dispatch on types instead of symbols
rossviljoen Jul 21, 2021
56507a8
Update doctrings
rossviljoen Jul 21, 2021
857ecc3
Enforce type for MonteCarlo and GaussHermite
rossviljoen Jul 21, 2021
1bbf385
Added error for Analytic
rossviljoen Jul 21, 2021
bbd8502
Rename GaussHermite to Quadrature
rossviljoen Jul 21, 2021
0563d01
Assume homoscedastic Gaussian noise
rossviljoen Jul 21, 2021
fb9a563
Add tests for `expected_loglik`
rossviljoen Jul 21, 2021
e62fbf7
Require ExpLink for Poisson closed form
rossviljoen Jul 24, 2021
36c62b9
Better error message
rossviljoen Jul 24, 2021
0ee1004
Added close form for Gamma and Exponential
rossviljoen Jul 24, 2021
f648a7c
Fix docstring
rossviljoen Jul 24, 2021
a9b9a57
Update docstring
rossviljoen Jul 24, 2021
b8e7d6b
Fix docstring
rossviljoen Jul 26, 2021
9353e44
Restrict types for continuous distributions
rossviljoen Jul 26, 2021
ea3d3c6
Use `AbstractGPs.approx_posterior` and `elbo`
rossviljoen Jul 26, 2021
c1a4546
Minor formatting
rossviljoen Jul 27, 2021
835da22
Dispatch on filled diagonal matrix obs noise
rossviljoen Jul 27, 2021
fa1cdc3
Add elbo tests
rossviljoen Jul 27, 2021
af41ca3
Small test changes
rossviljoen Jul 27, 2021
de2c4cd
Fix elbo error
rossviljoen Jul 27, 2021
f07c6f1
Remove qualifier from kldivergence
rossviljoen Jul 28, 2021
9f4d295
Check for ZeroMean
rossviljoen Jul 28, 2021
ca5f148
Fix classification example jitter
rossviljoen Jul 28, 2021
66ec256
Remove unnecessary imports from AbstractGPs
rossviljoen Jul 28, 2021
6841074
Better cholesky of covariance methods
rossviljoen Jul 28, 2021
1594ee8
Use KLDivergences
rossviljoen Jul 28, 2021
878b214
Use vector of marginals `q_f` vs. `f_mean, f_var`
rossviljoen Jul 28, 2021
be96722
Ran JuliaFormatter
rossviljoen Jul 30, 2021
39f243a
Revert "Ran JuliaFormatter"
rossviljoen Jul 30, 2021
ef3292c
Reformat with JuliaFormatter - BlueStyle
rossviljoen Jul 30, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
name = "SparseGPs"
uuid = "298c2ebc-0411-48ad-af38-99e88101b606"
authors = ["Ross Viljoen <[email protected]>"]
version = "0.1.0"

[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
GPLikelihoods = "6031954c-0455-49d7-b3b9-3e1c99afaf40"
KLDivergences = "3c9cd921-3d3f-41e2-830c-e020174918cc"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
160 changes: 160 additions & 0 deletions examples/classification.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# Recreation of https://gpflow.readthedocs.io/en/master/notebooks/basics/classification.html

# %%
using SparseGPs
using AbstractGPs
using GPLikelihoods
using StatsFuns
using FastGaussQuadrature
using Distributions
using LinearAlgebra
using DelimitedFiles
using IterTools

using Plots
default(; legend=:outertopright, size=(700, 400))

using Random
Random.seed!(1234)

# %%
# Read in the classification data
data_file = pkgdir(SparseGPs) * "/examples/data/classif_1D.csv"
x, y = eachcol(readdlm(data_file))
scatter(x, y)


# %%
# First, create the GP kernel from given parameters k
function make_kernel(k)
return softplus(k[1]) * (SqExponentialKernel() ∘ ScaleTransform(softplus(k[2])))
end

k = [10, 0.1]

kernel = make_kernel(k)
f = LatentGP(GP(kernel), BernoulliLikelihood(), 0.1)
rossviljoen marked this conversation as resolved.
Show resolved Hide resolved
fx = f(x)


# %%
# Then, plot some samples from the prior underlying GP
x_plot = 0:0.02:6
prior_f_samples = rand(f.f(x_plot, 1e-6),20)

plt = plot(
x_plot,
prior_f_samples;
seriescolor="red",
linealpha=0.2,
label=""
)
scatter!(plt, x, y; seriescolor="blue", label="Data points")


# %%
# Plot the same samples, but pushed through a logistic sigmoid to constrain
# them in (0, 1).
prior_y_samples = mean.(f.lik.(prior_f_samples))
rossviljoen marked this conversation as resolved.
Show resolved Hide resolved

plt = plot(
x_plot,
prior_y_samples;
seriescolor="red",
linealpha=0.2,
label=""
)
scatter!(plt, x, y; seriescolor="blue", label="Data points")


# %%
# A simple Flux model
using Flux

struct SVGPModel
k # kernel parameters
m # variational mean
A # variational covariance
z # inducing points
end

@Flux.functor SVGPModel (k, m, A,) # Don't train the inducing inputs

lik = BernoulliLikelihood()
jitter = 1e-4

function (m::SVGPModel)(x)
kernel = make_kernel(m.k)
f = LatentGP(GP(kernel), lik, jitter)
q = MvNormal(m.m, m.A'm.A)
fx = f(x)
fu = f(m.z).fx
return fx, fu, q
end

function flux_loss(x, y; n_data=length(y))
fx, fu, q = model(x)
return -SparseGPs.elbo(fx, y, fu, q; n_data, method=MonteCarlo())
end

# %%
M = 15 # number of inducing points

# Initialise the parameters
k = [10, 0.1]
m = zeros(M)
A = Matrix{Float64}(I, M, M)
z = x[1:M]

model = SVGPModel(k, m, A, z)

opt = ADAM(0.1)
parameters = Flux.params(model)

# %%
# Negative ELBO before training
println(flux_loss(x, y))

# %%
# Train the model
Flux.train!(
(x, y) -> flux_loss(x, y),
parameters,
ncycle([(x, y)], 2000), # Train for 1000 epochs
opt
)

# %%
# Negative ELBO after training
println(flux_loss(x, y))

# %%
# After optimisation, plot samples from the underlying posterior GP.
fu = f(z).fx # want the underlying FiniteGP
post = SparseGPs.approx_posterior(SVGP(), fu, MvNormal(m, A'A))
l_post = LatentGP(post, BernoulliLikelihood(), jitter)

post_f_samples = rand(l_post.f(x_plot, 1e-6), 20)

rossviljoen marked this conversation as resolved.
Show resolved Hide resolved
plt = plot(
x_plot,
post_f_samples;
seriescolor="red",
linealpha=0.2,
legend=false
)

# %%
# As above, push these samples through a logistic sigmoid to get posterior predictions.
post_y_samples = mean.(l_post.lik.(post_f_samples))

plt = plot(
x_plot,
post_y_samples;
seriescolor="red",
linealpha=0.2,
# legend=false,
label=""
)
scatter!(plt, x, y; seriescolor="blue", label="Data points")
vline!(z; label="Pseudo-points")
50 changes: 50 additions & 0 deletions examples/data/classif_1D.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
5.668341708542713242e+00 0.000000000000000000e+00
5.758793969849246075e+00 0.000000000000000000e+00
5.517587939698492150e+00 0.000000000000000000e+00
2.954773869346733584e+00 1.000000000000000000e+00
3.648241206030150785e+00 1.000000000000000000e+00
2.110552763819095290e+00 1.000000000000000000e+00
4.613065326633165597e+00 0.000000000000000000e+00
4.793969849246231263e+00 0.000000000000000000e+00
4.703517587939698430e+00 0.000000000000000000e+00
6.030150753768843686e-01 1.000000000000000000e+00
3.015075376884421843e-01 0.000000000000000000e+00
3.979899497487437099e+00 0.000000000000000000e+00
3.226130653266331638e+00 1.000000000000000000e+00
1.899497487437185939e+00 1.000000000000000000e+00
1.145728643216080256e+00 1.000000000000000000e+00
3.316582914572864249e-01 0.000000000000000000e+00
6.030150753768843686e-01 1.000000000000000000e+00
2.231155778894472252e+00 1.000000000000000000e+00
3.256281407035175768e+00 1.000000000000000000e+00
1.085427135678391997e+00 1.000000000000000000e+00
1.809045226130653106e+00 1.000000000000000000e+00
4.492462311557789079e+00 0.000000000000000000e+00
1.959798994974874198e+00 1.000000000000000000e+00
0.000000000000000000e+00 0.000000000000000000e+00
3.346733668341708601e+00 1.000000000000000000e+00
1.507537688442210921e-01 0.000000000000000000e+00
1.809045226130653328e-01 1.000000000000000000e+00
5.517587939698492150e+00 0.000000000000000000e+00
2.201005025125628123e+00 1.000000000000000000e+00
5.577889447236180409e+00 0.000000000000000000e+00
1.809045226130653328e-01 0.000000000000000000e+00
1.688442211055276365e+00 1.000000000000000000e+00
4.160804020100502321e+00 0.000000000000000000e+00
2.170854271356783993e+00 1.000000000000000000e+00
4.311557788944723413e+00 0.000000000000000000e+00
3.075376884422110546e+00 1.000000000000000000e+00
5.125628140703517133e+00 0.000000000000000000e+00
1.989949748743718549e+00 1.000000000000000000e+00
5.366834170854271058e+00 0.000000000000000000e+00
4.100502512562814061e+00 0.000000000000000000e+00
7.236180904522613311e-01 1.000000000000000000e+00
2.261306532663316382e+00 1.000000000000000000e+00
3.467336683417085119e+00 1.000000000000000000e+00
1.085427135678391997e+00 1.000000000000000000e+00
5.095477386934673447e+00 0.000000000000000000e+00
5.185929648241205392e+00 0.000000000000000000e+00
2.743718592964823788e+00 1.000000000000000000e+00
2.773869346733668362e+00 1.000000000000000000e+00
1.417085427135678311e+00 1.000000000000000000e+00
1.989949748743718549e+00 1.000000000000000000e+00
Loading