Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ApproximatePeriodicKernel #98

Merged
merged 30 commits into from
May 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
395f174
Implement PeriodicKernel
theogf Mar 14, 2023
b5aebda
Allow for sum of many kernels
theogf Mar 14, 2023
76f2ae9
Try using BlockDiagonal
theogf Mar 14, 2023
0ccc3f1
wip block diagonal
theogf Mar 15, 2023
917f354
Create ApproxPeriodicKernel
theogf Mar 21, 2023
91d020f
Merge branch 'master' into tgf/flexible-kernel
theogf Mar 21, 2023
a701793
Formalize ApproxPeriodicKernel
theogf Mar 21, 2023
48c3adc
Adjustments
theogf Mar 21, 2023
8fe3cf5
New tests
theogf Mar 21, 2023
4c3d677
Latest minor changes
theogf Mar 21, 2023
23e179d
Merge branch 'master' into tgf/flexible-kernel
theogf Apr 4, 2023
be81832
Clean up and fixes
theogf Apr 4, 2023
07c7066
Minor changes
theogf Apr 4, 2023
358c18b
WIP ApproxPeriodicKernel
theogf Apr 11, 2023
1aef9c8
Find correct formulation
theogf Apr 11, 2023
2c31442
Remove workaround
theogf Apr 11, 2023
acd633b
Use log transform?
theogf Apr 11, 2023
766868e
Remove log-transform
theogf Apr 11, 2023
06634f4
Update src/gp/lti_sde.jl
theogf Apr 11, 2023
cd2851c
Merge branch 'master' into tgf/flexible-kernel
theogf Apr 11, 2023
b1c1407
Patch bump
theogf Apr 11, 2023
26b8c11
Build new constructor and larger value
theogf Apr 18, 2023
e8c4c40
Merge branch 'tgf/flexible-kernel' of github.com:JuliaGaussianProcess…
theogf Apr 18, 2023
481fc5e
Fix constructors
theogf Apr 18, 2023
bd66ae0
Merge branch 'master' into tgf/flexible-kernel
theogf Apr 25, 2023
1e6cff9
Adjust tests for PeriodicKernel
theogf Apr 25, 2023
0a49410
Patch bump
theogf Apr 25, 2023
0af2f1e
Update Project.toml
theogf Apr 27, 2023
967c1f9
Merge branch 'master' into tgf/flexible-kernel
theogf Apr 27, 2023
718b1bd
Update lti_sde.jl
theogf Apr 27, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
name = "TemporalGPs"
uuid = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f"
authors = ["willtebbutt <[email protected]> and contributors"]
version = "0.6.4"
version = "0.6.5"

[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
Bessels = "0e736298-9ec6-45e8-9647-e4fc86a2fe38"
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Expand Down
4 changes: 3 additions & 1 deletion src/TemporalGPs.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module TemporalGPs

using AbstractGPs
using Bessels: besseli
using BlockDiagonals
using ChainRulesCore
import ChainRulesCore: rrule
Expand Down Expand Up @@ -31,7 +32,8 @@ module TemporalGPs
checkpointed,
posterior,
logpdf_and_rand,
Separable
Separable,
ApproxPeriodicKernel

# Various bits-and-bobs. Often committing some type piracy.
include(joinpath("util", "harmonise.jl"))
Expand Down
87 changes: 87 additions & 0 deletions src/gp/lti_sde.jl
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,93 @@ function stationary_distribution(::Matern52Kernel, ::SArrayStorage{T}) where {T<
return Gaussian(m, P)
end

# Cosine

# Return the `(F, q, H)` triple used for a `CosineKernel` with static-array storage:
# a 2x2 matrix `F`, a scalar `q` equal to `zero(T)`, and a length-2 vector `H`.
function to_sde(::CosineKernel, ::SArrayStorage{T}) where {T}
    # The four entries are passed to the `SMatrix` constructor in the order (0, 1, -1, 0);
    # with StaticArrays' column-major convention this is F = [0 -1; 1 0].
    drift = SMatrix{2, 2, T}(0, 1, -1, 0)
    emission = SVector{2, T}(1, 0)
    return drift, zero(T), emission
end

# Return the `Gaussian` used as stationary distribution for a `CosineKernel` with
# static-array storage: zero mean of length 2 and the 2x2 identity as covariance.
function stationary_distribution(::CosineKernel, ::SArrayStorage{T}) where {T<:Real}
    mean_vec = SVector{2, T}(0, 0)
    cov_mat = SMatrix{2, 2, T}(1, 0, 0, 1)
    return Gaussian(mean_vec, cov_mat)
end

# Approximate Periodic Kernel
# The periodic kernel is approximated by a sum of cosine kernels with different frequencies.
"""
    ApproxPeriodicKernel{N,K<:PeriodicKernel}

Wrapper around a `PeriodicKernel` (stored in the field `kernel`) that carries the number
`N` of cosine-kernel terms as a type parameter.

The inner constructor rejects kernels whose `r` field does not have exactly one element.
"""
struct ApproxPeriodicKernel{N,K<:PeriodicKernel} <: KernelFunctions.SimpleKernel
    kernel::K

    function ApproxPeriodicKernel{N,K}(kernel::K) where {N,K<:PeriodicKernel}
        if length(kernel.r) != 1
            error("ApproxPeriodicKernel only supports a single lengthscale")
        end
        return new{N,K}(kernel)
    end
end
# We follow "State Space approximation of Gaussian Processes for time series forecasting"
# by Alessio Benavoli and Giorgio Corani and use a default of 7 Cosine Kernel terms
# Convenience constructors. When `N` is not given it defaults to 7 terms; when only the
# keyword `r` is given, a `PeriodicKernel` with the one-element lengthscale `[r]` is built.
function ApproxPeriodicKernel{N}(kernel::K) where {N,K<:PeriodicKernel}
    return ApproxPeriodicKernel{N,K}(kernel)
end

function ApproxPeriodicKernel{N}(; r::Real=1.0) where {N}
    return ApproxPeriodicKernel{N}(PeriodicKernel(; r=[r]))
end

function ApproxPeriodicKernel(kernel::PeriodicKernel)
    return ApproxPeriodicKernel{7}(kernel)
end

function ApproxPeriodicKernel(; r::Real=1.0)
    return ApproxPeriodicKernel{7}(PeriodicKernel(; r=[r]))
end

# Kernel evaluation is forwarded unchanged to the wrapped `PeriodicKernel`.
function KernelFunctions.kappa(k::ApproxPeriodicKernel, x)
    return KernelFunctions.kappa(k.kernel, x)
end

function KernelFunctions.metric(k::ApproxPeriodicKernel)
    return KernelFunctions.metric(k.kernel)
end

# One-line description: the single lengthscale of the wrapped kernel and the number of
# cosine terms `N`.
function Base.show(io::IO, κ::ApproxPeriodicKernel{N}) where {N}
    description = "Approximate Periodic Kernel, (r = $(only(κ.kernel.r))) approximated with $N cosine kernels"
    return print(io, description)
end

# LGSSM components of an `ApproxPeriodicKernel` on a regularly spaced time grid.
# Since the step is constant, each of the `N` per-term transition matrices is computed
# once with `time_exp` and repeated `length(t)` times through a `Fill`.
# Returns whatever `_reduce_sum_cosine_kernel_lgssm` returns: `As, as, Qs, (Hs, h), x0`.
function lgssm_components(
    approx::ApproxPeriodicKernel{N},
    t::Union{StepRangeLen, RegularSpacing},
    storage::StorageType{T},
) where {N,T<:Real}
    Fs, Hs, ms, Ps = _init_periodic_kernel_lgssm(approx.kernel, storage, N)
    n_steps = length(t)
    transitions = map(Fs) do F
        Fill(time_exp(F, T(step(t))), n_steps)
    end
    return _reduce_sum_cosine_kernel_lgssm(transitions, Hs, ms, Ps, N, n_steps, T)
end
# LGSSM components of an `ApproxPeriodicKernel` on an arbitrary (irregular) time grid.
# One transition matrix per term and per time step is computed with `time_exp`.
# Returns whatever `_reduce_sum_cosine_kernel_lgssm` returns: `As, as, Qs, (Hs, h), x0`.
function lgssm_components(
    approx::ApproxPeriodicKernel{N},
    t::AbstractVector{<:Real},
    storage::StorageType{T},
) where {N,T<:Real}
    Fs, Hs, ms, Ps = _init_periodic_kernel_lgssm(approx.kernel, storage, N)

    # Prepend a pseudo time point one unit before `first(t)` so that `diff` yields exactly
    # one step per element of `t` (the first step is therefore always 1).
    # The steps are computed once and bound to a fresh name: reassigning `t` and then
    # capturing it in the closure below would box the variable and hurt type inference,
    # and would recompute `diff` for every one of the `N` terms.
    Δts = diff(vcat([first(t) - 1], t))
    nt = length(Δts)

    As = _map(F -> _map(Δt -> time_exp(F, T(Δt)), Δts), Fs)
    return _reduce_sum_cosine_kernel_lgssm(As, Hs, ms, Ps, N, nt, T)
end

# Build the per-term building blocks of the cosine-sum approximation of `kernel`.
#
# With a = 1 / (4 r^2), where `r` is the single lengthscale of `kernel`, term j (1-based)
# uses
#   * the cosine-kernel matrix `F` scaled by 2π (j - 1),
#   * the cosine-kernel emission vector `H` and stationary mean, unchanged,
#   * the cosine-kernel stationary covariance scaled by
#       qⱼ = c * besseli(j - 1, a) / exp(a),   c = 1 for j == 1 and c = 2 otherwise.
#
# Returns `(Fs, Hs, ms, Ps)`, each with `N` entries (`Fs`/`Ps` as tuples, `Hs`/`ms` as `Fill`s).
function _init_periodic_kernel_lgssm(kernel::PeriodicKernel, storage, N::Int=7)
    a = inv(4 * only(kernel.r)^2)

    # Components of a single (unit-frequency) cosine kernel; its `q` is not needed here.
    F_cos, _, H_cos = to_sde(CosineKernel(), storage)
    x0_cos = stationary_distribution(CosineKernel(), storage)
    P_cos = x0_cos.P

    Fs = ntuple(i -> 2π * (i - 1) * F_cos, N)
    Hs = Fill(H_cos, N)
    ms = Fill(x0_cos.m, N)
    Ps = ntuple(N) do j
        multiplicity = j == 1 ? 1 : 2
        weight = multiplicity * besseli(j - 1, a) / exp(a)
        weight * P_cos
    end

    return Fs, Hs, ms, Ps
end

# Combine the `N` per-term components into the components of a single model.
#
# `As` holds, for each term, a sequence of `nt` transition matrices; `Hs`, `ms`, `Ps` hold
# one emission vector, mean and covariance per term. Per-term matrices are assembled with
# `block_diagonal`, per-term vectors are concatenated with `vcat`.
#
# Returns `(As, as, Qs, (Hs, h), x0)` where `x0` is a `Gaussian` and `h` is a `Fill` of
# `zero(T)` of length `nt`.
function _reduce_sum_cosine_kernel_lgssm(As, Hs, ms, Ps, N, nt, T)
    # Per-term offsets are all zero; their length is the row count of a per-term matrix.
    term_dim = size(first(first(As)), 1)
    term_offsets = Fill(Fill(Zeros{T}(term_dim), nt), N)

    # Per-term, per-step noise covariance: P - A * P * A'.
    term_Qs = _map(Ps, As) do P, A_steps
        _map(A_step -> Symmetric(P) - A_step * Symmetric(P) * A_step', A_steps)
    end

    emissions = Fill(vcat(Hs...), nt)
    h = Fill(zero(T), nt)

    transitions = _map(block_diagonal, As...)
    offsets = -map(vcat, term_offsets...)
    noise_covs = _map(block_diagonal, term_Qs...)

    x0 = Gaussian(reduce(vcat, ms), block_diagonal(Ps...))
    return transitions, offsets, noise_covs, (emissions, h), x0
end

# Constant

function TemporalGPs.to_sde(::ConstantKernel, ::SArrayStorage{T}) where {T<:Real}
Expand Down
3 changes: 1 addition & 2 deletions src/models/lgssm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -267,8 +267,7 @@ ChainRulesCore.@non_differentiable ident_eps(args...)

# `_collect` materialises an adjoint-wrapped `Matrix` with `collect`, and returns
# `SMatrix` and `BlockDiagonal` arguments unchanged.
function _collect(U::Adjoint{<:Any, <:Matrix})
    return collect(U)
end

function _collect(U::SMatrix)
    return U
end

function _collect(U::BlockDiagonal)
    return U
end

# AD stuff. No need to understand this unless you're really plumbing the depths...

Expand Down
1 change: 1 addition & 0 deletions src/util/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ function rrule(::Type{<:Fill}, x, sz)
Fill_rrule(Δ::Union{Fill,Thunk}) = NoTangent(), FillArrays.getindex_value(unthunk(Δ)), NoTangent()
Fill_rrule(Δ::Tangent{T,<:NamedTuple{(:value, :axes)}}) where {T} = NoTangent(), Δ.value, NoTangent()
Fill_rrule(::AbstractZero) = NoTangent(), NoTangent(), NoTangent()
Fill_rrule(Δ::Tangent{T,<:NTuple}) where {T} = NoTangent(), sum(Δ), NoTangent()
function Fill_rrule(Δ::AbstractArray)
# all(==(first(Δ)), Δ) || error("Δ should be a vector of the same value")
# sum(Δ)
Expand Down
102 changes: 79 additions & 23 deletions test/gp/lti_sde.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
using TemporalGPs: build_lgssm, StorageType, is_of_storage_type
using KernelFunctions
using KernelFunctions: kappa
using ChainRulesTestUtils
using TemporalGPs: build_lgssm, StorageType, is_of_storage_type, lgssm_components
using Test
include("../test_util.jl")
include("../models/model_test_utils.jl")
# Logistic (sigmoid) function: 1 / (1 + exp(-x)).
function _logistic(x)
    return inv(1 + exp(-x))
end
Expand All @@ -12,6 +15,34 @@ function _construction_tester(f_naive::GP, storage::StorageType, σ², t::Abstra
return build_lgssm(fx)
end

@testset "ApproxPeriodicKernel" begin
    # The default constructor uses 7 terms.
    approx = ApproxPeriodicKernel()
    @test approx isa ApproxPeriodicKernel{7}

    # As a kernel it must evaluate like the PeriodicKernel it wraps.
    reference = PeriodicKernel()
    d = rand()
    @test kappa(approx, d) == kappa(reference, d)
    xs = rand(3)
    @test kernelmatrix(approx, xs) ≈ kernelmatrix(reference, xs)

    # Every LGSSM component has one entry per time point, and the state dimension is 2N.
    Nt = 10
    @testset "$(typeof(t)), $storage, $N" for t in (
            sort(rand(Nt)), RegularSpacing(0.0, 0.1, Nt)
        ),
        storage in (SArrayStorage{Float64}(), ArrayStorage{Float64}()),
        N in (5, 8)

        state_dim = N * 2
        approx_N = ApproxPeriodicKernel{N}()
        As, as, Qs, emission_projections, x0 = lgssm_components(approx_N, t, storage)

        @test length(As) == Nt
        @test length(as) == Nt
        @test length(Qs) == Nt
        @test all(A -> size(A) == (state_dim, state_dim), As)
        @test all(a -> size(a) == (state_dim,), as)
        @test all(Q -> size(Q) == (state_dim, state_dim), Qs)
    end
end

println("lti_sde:")
@testset "lti_sde" begin
@testset "block_diagonal" begin
Expand All @@ -37,7 +68,11 @@ println("lti_sde:")
)

kernels = [
Matern12Kernel(), Matern32Kernel(), Matern52Kernel(), ConstantKernel(; c=1.5)
Matern12Kernel(),
Matern32Kernel(),
Matern52Kernel(),
ConstantKernel(; c=1.5),
CosineKernel(),
]

@testset "$kernel, $(storage.name)" for kernel in kernels, storage in storages
Expand All @@ -56,53 +91,60 @@ println("lti_sde:")
N = 13
kernels = vcat(
# Base kernels.
(name="base-Matern12Kernel", val=Matern12Kernel()),
(name="base-Matern12Kernel", val=Matern12Kernel(), to_vec_grad=false),
map([Matern32Kernel, Matern52Kernel]) do k
(; name="base-$k", val=k())
(; name="base-$k", val=k(), to_vec_grad=false)
end,

# Scaled kernels.
map([1e-1, 1.0, 10.0, 100.0]) do σ²
(; name="scaled-σ²=$σ²", val=σ² * Matern32Kernel())
(; name="scaled-σ²=$σ²", val=σ² * Matern32Kernel(), to_vec_grad=false)
end,

# Stretched kernels.
map([1e-2, 0.1, 1.0, 10.0, 100.0]) do λ
(; name="stretched-λ=$λ", val=Matern32Kernel() ∘ ScaleTransform(λ))
(; name="stretched-λ=$λ", val=Matern32Kernel() ∘ ScaleTransform(λ), to_vec_grad=false)
end,

# Approx periodic kernels
map([7, 11]) do N
(
name="approx-periodic-N=$N",
val=ApproxPeriodicKernel{N}(; r=1.0),
to_vec_grad=true,
)
end,
# TEST_TOFIX
# Gradients should be fixed on those composites.
# Error is mostly due to an incompatibility of Tangents
# between Zygote and FiniteDifferences.

# Product kernels
(
name="prod-Matern12Kernel-Matern32Kernel",
val=1.5 * Matern12Kernel() ∘ ScaleTransform(0.1) *
Matern32Kernel() ∘ ScaleTransform(1.1),
skip_grad=true,
),
(
val=1.5 * Matern12Kernel() ∘ ScaleTransform(0.1) * Matern32Kernel() ∘
ScaleTransform(1.1),
to_vec_grad=nothing,
),
(
name="prod-Matern32Kernel-Matern52Kernel-ConstantKernel",
val = 3.0 * Matern32Kernel() *
Matern52Kernel() *
ConstantKernel(),
skip_grad=true,
val=3.0 * Matern32Kernel() * Matern52Kernel() * ConstantKernel(),
to_vec_grad=nothing,
),

# Summed kernels.
(
name="sum-Matern12Kernel-Matern32Kernel",
val=1.5 * Matern12Kernel() ∘ ScaleTransform(0.1) +
0.3 * Matern32Kernel() ∘ ScaleTransform(1.1),
skip_grad=true,
),
to_vec_grad=nothing,
),
(
name="sum-Matern32Kernel-Matern52Kernel-ConstantKernel",
val = 2.0 * Matern32Kernel() +
val=2.0 * Matern32Kernel() +
0.5 * Matern52Kernel() +
1.0 * ConstantKernel(),
skip_grad=true,
to_vec_grad=nothing,
),
)

Expand All @@ -126,14 +168,14 @@ println("lti_sde:")
(name="Custom Mean", val=CustomMean(x -> 2x)),
)

@testset "$(kernel.name), $(m.name), $(storage.name), $(t.name), $(σ².name)" for
kernel in kernels,
@testset "$(kernel.name), $(m.name), $(storage.name), $(t.name), $(σ².name)" for kernel in
kernels,
m in means,
storage in storages,
t in ts,
σ² in σ²s

println("$(kernel.name), $(storage.name), $(t.name), $(σ².name)")
println("$(kernel.name), $(storage.name), $(m.name), $(t.name), $(σ².name)")

# Construct Gauss-Markov model.
f_naive = GP(m.val, kernel.val)
Expand Down Expand Up @@ -174,7 +216,21 @@ println("lti_sde:")
end

# Just need to ensure we can differentiate through construction properly.
if !(hasfield(typeof(kernel), :skip_grad) && kernel.skip_grad)
if isnothing(kernel.to_vec_grad)
@test_broken "Gradient tests are not passing"
continue
elseif kernel.to_vec_grad
test_zygote_grad_finite_differences_compatible(
_construction_tester,
f_naive,
storage.val,
σ².val,
t.val;
check_inferred=false,
rtol=1e-6,
atol=1e-6,
)
else
test_zygote_grad(
_construction_tester,
f_naive,
Expand Down
10 changes: 9 additions & 1 deletion test/test_util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ test_zygote_grad(f, args...; check_inferred=false, kwargs...) = test_rrule(Zygot
function test_zygote_grad_finite_differences_compatible(f, args...; kwargs...)
x_vec, from_vec = to_vec(args)
function finite_diff_compatible_f(x::AbstractVector)
return @ignore_derivatives(f)(from_vec(x)...)
return @ignore_derivatives(f)(@ignore_derivatives(from_vec)(x)...)
end
test_zygote_grad(finite_diff_compatible_f ⊢ NoTangent(), x_vec; testset_name="test_rrule: $(f) on $(typeof.(args))", kwargs...)
end
Expand Down Expand Up @@ -134,6 +134,14 @@ function ChainRulesTestUtils.test_approx(actual::Tangent{<:Fill}, expected, msg=
test_approx(actual.value, expected.value, msg; kwargs...)
end

# Flatten a `PeriodicKernel` into a vector holding the element-wise log of its `r` field.
# The returned closure exponentiates the vector again before rebuilding the kernel.
function to_vec(x::PeriodicKernel)
    r_vec, r_from_vec = to_vec(x.r)
    function PeriodicKernel_from_vec(log_r_vec)
        return PeriodicKernel(; r=exp.(r_from_vec(log_r_vec)))
    end
    return log.(r_vec), PeriodicKernel_from_vec
end

# Fallback: any value without a more specific `to_vec` method is handled by
# `generic_struct_to_vec`.
function to_vec(x::T) where {T}
    return generic_struct_to_vec(x)
end

# This is a copy from FiniteDifferences.jl without the try catch
Expand Down