Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Soss test (update of #62) #135

Merged
merged 14 commits into from
Apr 14, 2021
4 changes: 4 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SampleChainsDynamicHMC = "6d9fd711-e8b2-4778-9c70-c1dfb499d4c4"
Soss = "8ce77f84-9b61-11e8-39ff-d17a774bf41c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
Expand All @@ -19,6 +21,8 @@ Documenter = "0.24, 0.25, 0.26"
FillArrays = "0.11"
FiniteDifferences = "0.9.6, 0.10, 0.11, 0.12"
Plots = "1"
SampleChainsDynamicHMC = "0.1"
Soss = "0.17"
Turing = "0.14, 0.15"
Zygote = "0.5, 0.6"
julia = "1.3"
51 changes: 51 additions & 0 deletions test/compat/soss.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
@testset "soss compat" begin
@testset "GP regression" begin
k = SqExponentialKernel()
y = randn(3)
X = randn(3, 1)
x = [rand(1) for _ in 1:3]

gp_regression = Soss.@model X begin
# Priors.
α ~ LogNormal(0.0, 0.1)
ρ ~ LogNormal(0.0, 1.0)
σ² ~ LogNormal(0.0, 1.0)

# Realized covariance function
kernel = α * transform(SqExponentialKernel(), 1 / ρ)
f = GP(kernel)

# Sampling Distribution.
y ~ f(X, σ²)
end

# Test for matrices
m = gp_regression(; X=RowVecs(X))
@test length(Soss.sample(DynamicHMCChain, (m | (y=y,)), 5)) == 5
devmotion marked this conversation as resolved.
Show resolved Hide resolved

# Test for vectors of vector
m = gp_regression(; X=x)
@test length(Soss.sample(DynamicHMCChain, (m | (y=y,)), 5)) == 5
devmotion marked this conversation as resolved.
Show resolved Hide resolved
end
@testset "latent GP regression" begin
X = randn(3, 1)
x = [rand(1) for _ in 1:3]
y = rand.(Poisson.(exp.(randn(3))))

latent_gp_regression = Soss.@model X begin
f = GP(Matern32Kernel())
u ~ f(X)
λ = exp.(u)
y ~ For(eachindex(λ)) do i
Poisson(λ[i])
end
end

m = latent_gp_regression(; X=RowVecs(X))
@test length(Soss.sample(DynamicHMCChain, (m | (y=y,)), 5)) == 5

# Test for vectors of vector
m = latent_gp_regression(; X=x)
@test length(Soss.sample(DynamicHMCChain, (m | (y=y,)), 5)) == 5
end
end
13 changes: 7 additions & 6 deletions test/turing.jl → test/compat/turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
y = randn(3)
X = randn(3, 1)
x = [rand(1) for _ in 1:3]
@model function GPRegression(y, X)

Turing.@model function GPRegression(y, X)
# Priors.
α ~ LogNormal(0.0, 0.1)
ρ ~ LogNormal(0.0, 1.0)
Expand All @@ -19,26 +20,26 @@
end
# Test for matrices
m = GPRegression(y, RowVecs(X))
@test length(sample(m, HMC(0.5, 1), 5)) == 5
@test length(sample(m, Turing.HMC(0.5, 1), 5)) == 5
# Test for vectors of vector
m = GPRegression(y, x)
@test length(sample(m, HMC(0.5, 1), 5)) == 5
@test length(sample(m, Turing.HMC(0.5, 1), 5)) == 5
end
@testset "latent GP regression" begin
X = randn(3, 1)
x = [rand(1) for _ in 1:3]
y = rand.(Poisson.(exp.(randn(3))))

@model function latent_gp_regression(y, X)
Turing.@model function latent_gp_regression(y, X)
f = GP(Matern32Kernel())
u ~ f(X)
λ = exp.(u)
return y .~ Poisson.(λ)
end
m = latent_gp_regression(y, RowVecs(X))
@test length(sample(m, NUTS(), 5)) == 5
@test length(sample(m, Turing.NUTS(), 5)) == 5
# Test for vectors of vector
m = latent_gp_regression(y, x)
@test length(sample(m, NUTS(), 5)) == 5
@test length(sample(m, Turing.NUTS(), 5)) == 5
end
end
18 changes: 13 additions & 5 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,19 @@ using AbstractGPs:

using Documenter
using ChainRulesCore
using Distributions: MvNormal, PDMat
using Distributions: MvNormal, PDMat, Poisson, LogNormal
using FillArrays
using FiniteDifferences
using FiniteDifferences: j′vp, to_vec
using LinearAlgebra
using LinearAlgebra: AbstractTriangular
using Plots
using Random
using SampleChainsDynamicHMC
using Soss: Soss
using Statistics
using Test
using Turing
using Turing: Turing
using Zygote

include("test_util.jl")
Expand Down Expand Up @@ -75,9 +77,15 @@ include("test_util.jl")
println(" ")
@info "Ran deprecation tests"

include("turing.jl")
println(" ")
@info "Ran Turing tests"
@testset "compat" begin
include(joinpath("compat", "turing.jl"))
println(" ")
@info "Ran Turing tests"

include(joinpath("compat", "soss.jl"))
println(" ")
@info "Ran Soss tests"
end

@testset "doctests" begin
DocMeta.setdocmeta!(
Expand Down