Skip to content

Commit

Permalink
Merge pull request #98 from JuliaGaussianProcesses/tgf/flexible-kernel
Browse files Browse the repository at this point in the history
Add ApproximatePeriodicKernel
  • Loading branch information
theogf authored May 9, 2023
2 parents 7a6dc8f + 718b1bd commit a931ac0
Show file tree
Hide file tree
Showing 7 changed files with 182 additions and 28 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
name = "TemporalGPs"
uuid = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f"
authors = ["willtebbutt <[email protected]> and contributors"]
version = "0.6.4"
version = "0.6.5"

[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
Bessels = "0e736298-9ec6-45e8-9647-e4fc86a2fe38"
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Expand Down
4 changes: 3 additions & 1 deletion src/TemporalGPs.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module TemporalGPs

using AbstractGPs
using Bessels: besseli
using BlockDiagonals
using ChainRulesCore
import ChainRulesCore: rrule
Expand Down Expand Up @@ -31,7 +32,8 @@ module TemporalGPs
checkpointed,
posterior,
logpdf_and_rand,
Separable
Separable,
ApproxPeriodicKernel

# Various bits-and-bobs. Often commiting some type piracy.
include(joinpath("util", "harmonise.jl"))
Expand Down
87 changes: 87 additions & 0 deletions src/gp/lti_sde.jl
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,93 @@ function stationary_distribution(::Matern52Kernel, ::SArrayStorage{T}) where {T<
return Gaussian(m, P)
end

# Cosine

function to_sde(::CosineKernel, ::SArrayStorage{T}) where {T}
F = SMatrix{2, 2, T}(0, 1, -1, 0)
q = zero(T)
H = SVector{2, T}(1, 0)
return F, q, H
end

function stationary_distribution(::CosineKernel, ::SArrayStorage{T}) where {T<:Real}
m = SVector{2, T}(0, 0)
P = SMatrix{2, 2, T}(1, 0, 0, 1)
return Gaussian(m, P)
end

# Approximate Periodic Kernel
# The periodic kernel is approximated by a sum of cosine kernels with different frequencies.
struct ApproxPeriodicKernel{N,K<:PeriodicKernel} <: KernelFunctions.SimpleKernel
kernel::K
function ApproxPeriodicKernel{N,K}(kernel::K) where {N,K<:PeriodicKernel}
length(kernel.r) == 1 || error("ApproxPeriodicKernel only supports a single lengthscale")
return new{N,K}(kernel)
end
end
# We follow "State Space approximation of Gaussian Processes for time series forecasting"
# by Alessio Benavoli and Giorgio Corani and use a default of 7 Cosine Kernel terms
ApproxPeriodicKernel(;r::Real=1.0) = ApproxPeriodicKernel{7}(PeriodicKernel(;r=[r]))
ApproxPeriodicKernel{N}(;r::Real=1.0) where {N} = ApproxPeriodicKernel{N}(PeriodicKernel(;r=[r]))
ApproxPeriodicKernel(kernel::PeriodicKernel) = ApproxPeriodicKernel{7}(kernel)
ApproxPeriodicKernel{N}(kernel::K) where {N,K<:PeriodicKernel} = ApproxPeriodicKernel{N,K}(kernel)

KernelFunctions.kappa(k::ApproxPeriodicKernel, x) = KernelFunctions.kappa(k.kernel, x)
KernelFunctions.metric(k::ApproxPeriodicKernel) = KernelFunctions.metric(k.kernel)

function Base.show(io::IO, κ::ApproxPeriodicKernel{N}) where {N}
return print(io, "Approximate Periodic Kernel, (r = $(only.kernel.r))) approximated with $N cosine kernels")
end

function lgssm_components(approx::ApproxPeriodicKernel{N}, t::Union{StepRangeLen, RegularSpacing}, storage::StorageType{T}) where {N,T<:Real}
Fs, Hs, ms, Ps = _init_periodic_kernel_lgssm(approx.kernel, storage, N)
nt = length(t)
As = map(F -> Fill(time_exp(F, T(step(t))), nt), Fs)
return _reduce_sum_cosine_kernel_lgssm(As, Hs, ms, Ps, N, nt, T)
end
function lgssm_components(approx::ApproxPeriodicKernel{N}, t::AbstractVector{<:Real}, storage::StorageType{T}) where {N,T<:Real}
Fs, Hs, ms, Ps = _init_periodic_kernel_lgssm(approx.kernel, storage, N)
t = vcat([first(t) - 1], t)
nt = length(diff(t))
As = _map(F -> _map(Δt -> time_exp(F, T(Δt)), diff(t)), Fs)
return _reduce_sum_cosine_kernel_lgssm(As, Hs, ms, Ps, N, nt, T)
end

function _init_periodic_kernel_lgssm(kernel::PeriodicKernel, storage, N::Int=7)
r = kernel.r
l⁻² = inv(4 * only(r)^2)

F, _, H = to_sde(CosineKernel(), storage)
Fs = ntuple(N) do i
2π * (i - 1) * F
end
Hs = Fill(H, N)

x0 = stationary_distribution(CosineKernel(), storage)
ms = Fill(x0.m, N)
P = x0.P
Ps = ntuple(N) do j
qⱼ = (1 + (j !== 1) ) * besseli(j - 1, l⁻²) / exp(l⁻²)
qⱼ * P
end

Fs, Hs, ms, Ps
end

function _reduce_sum_cosine_kernel_lgssm(As, Hs, ms, Ps, N, nt, T)
as = Fill(Fill(Zeros{T}(size(first(first(As)), 1)), nt), N)
Qs = _map((P, A) -> _map(A -> Symmetric(P) - A * Symmetric(P) * A', A), Ps, As)
Hs = Fill(vcat(Hs...), nt)
h = Fill(zero(T), nt)
As = _map(block_diagonal, As...)
as = -map(vcat, as...)
Qs = _map(block_diagonal, Qs...)
m = reduce(vcat, ms)
P = block_diagonal(Ps...)
x0 = Gaussian(m, P)
return As, as, Qs, (Hs, h), x0
end

# Constant

function TemporalGPs.to_sde(::ConstantKernel, ::SArrayStorage{T}) where {T<:Real}
Expand Down
3 changes: 1 addition & 2 deletions src/models/lgssm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -267,8 +267,7 @@ ChainRulesCore.@non_differentiable ident_eps(args...)

_collect(U::Adjoint{<:Any, <:Matrix}) = collect(U)
_collect(U::SMatrix) = U


_collect(U::BlockDiagonal) = U

# AD stuff. No need to understand this unless you're really plumbing the depths...

Expand Down
1 change: 1 addition & 0 deletions src/util/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ function rrule(::Type{<:Fill}, x, sz)
Fill_rrule::Union{Fill,Thunk}) = NoTangent(), FillArrays.getindex_value(unthunk(Δ)), NoTangent()
Fill_rrule::Tangent{T,<:NamedTuple{(:value, :axes)}}) where {T} = NoTangent(), Δ.value, NoTangent()
Fill_rrule(::AbstractZero) = NoTangent(), NoTangent(), NoTangent()
Fill_rrule::Tangent{T,<:NTuple}) where {T} = NoTangent(), sum(Δ), NoTangent()
function Fill_rrule::AbstractArray)
# all(==(first(Δ)), Δ) || error("Δ should be a vector of the same value")
# sum(Δ)
Expand Down
102 changes: 79 additions & 23 deletions test/gp/lti_sde.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
using TemporalGPs: build_lgssm, StorageType, is_of_storage_type
using KernelFunctions
using KernelFunctions: kappa
using ChainRulesTestUtils
using TemporalGPs: build_lgssm, StorageType, is_of_storage_type, lgssm_components
using Test
include("../test_util.jl")
include("../models/model_test_utils.jl")
_logistic(x) = 1 / (1 + exp(-x))
Expand All @@ -12,6 +15,34 @@ function _construction_tester(f_naive::GP, storage::StorageType, σ², t::Abstra
return build_lgssm(fx)
end

@testset "ApproxPeriodicKernel" begin
k = ApproxPeriodicKernel()
@test k isa ApproxPeriodicKernel{7}
# Test that it behaves like a normal PeriodicKernel
k_base = PeriodicKernel()
x = rand()
@test kappa(k, x) == kappa(k_base, x)
x = rand(3)
@test kernelmatrix(k, x) kernelmatrix(k_base, x)
# Test dimensionality of LGSSM components
Nt = 10
@testset "$(typeof(t)), $storage, $N" for t in (
sort(rand(Nt)), RegularSpacing(0.0, 0.1, Nt)
),
storage in (SArrayStorage{Float64}(), ArrayStorage{Float64}()),
N in (5, 8)

k = ApproxPeriodicKernel{N}()
As, as, Qs, emission_projections, x0 = lgssm_components(k, t, storage)
@test length(As) == Nt
@test all(x -> size(x) == (N * 2, N * 2), As)
@test length(as) == Nt
@test all(x -> size(x) == (N * 2,), as)
@test length(Qs) == Nt
@test all(x -> size(x) == (N * 2, N * 2), Qs)
end
end

println("lti_sde:")
@testset "lti_sde" begin
@testset "block_diagonal" begin
Expand All @@ -37,7 +68,11 @@ println("lti_sde:")
)

kernels = [
Matern12Kernel(), Matern32Kernel(), Matern52Kernel(), ConstantKernel(; c=1.5)
Matern12Kernel(),
Matern32Kernel(),
Matern52Kernel(),
ConstantKernel(; c=1.5),
CosineKernel(),
]

@testset "$kernel, $(storage.name)" for kernel in kernels, storage in storages
Expand All @@ -56,53 +91,60 @@ println("lti_sde:")
N = 13
kernels = vcat(
# Base kernels.
(name="base-Matern12Kernel", val=Matern12Kernel()),
(name="base-Matern12Kernel", val=Matern12Kernel(), to_vec_grad=false),
map([Matern32Kernel, Matern52Kernel]) do k
(; name="base-$k", val=k())
(; name="base-$k", val=k(), to_vec_grad=false)
end,

# Scaled kernels.
map([1e-1, 1.0, 10.0, 100.0]) do σ²
(; name="scaled-σ²=$σ²", val=σ² * Matern32Kernel())
(; name="scaled-σ²=$σ²", val=σ² * Matern32Kernel(), to_vec_grad=false)
end,

# Stretched kernels.
map([1e-2, 0.1, 1.0, 10.0, 100.0]) do λ
(; name="stretched-λ=", val=Matern32Kernel() ScaleTransform(λ))
(; name="stretched-λ=", val=Matern32Kernel() ScaleTransform(λ), to_vec_grad=false)
end,

# Approx periodic kernels
map([7, 11]) do N
(
name="approx-periodic-N=$N",
val=ApproxPeriodicKernel{N}(; r=1.0),
to_vec_grad=true,
)
end,
# TEST_TOFIX
# Gradients should be fixed on those composites.
# Error is mostly due do an incompatibility of Tangents
# between Zygote and FiniteDifferences.

# Product kernels
(
name="prod-Matern12Kernel-Matern32Kernel",
val=1.5 * Matern12Kernel() ScaleTransform(0.1) *
Matern32Kernel() ScaleTransform(1.1),
skip_grad=true,
),
(
val=1.5 * Matern12Kernel() ScaleTransform(0.1) * Matern32Kernel()
ScaleTransform(1.1),
to_vec_grad=nothing,
),
(
name="prod-Matern32Kernel-Matern52Kernel-ConstantKernel",
val = 3.0 * Matern32Kernel() *
Matern52Kernel() *
ConstantKernel(),
skip_grad=true,
val=3.0 * Matern32Kernel() * Matern52Kernel() * ConstantKernel(),
to_vec_grad=nothing,
),

# Summed kernels.
(
name="sum-Matern12Kernel-Matern32Kernel",
val=1.5 * Matern12Kernel() ScaleTransform(0.1) +
0.3 * Matern32Kernel() ScaleTransform(1.1),
skip_grad=true,
),
to_vec_grad=nothing,
),
(
name="sum-Matern32Kernel-Matern52Kernel-ConstantKernel",
val = 2.0 * Matern32Kernel() +
val=2.0 * Matern32Kernel() +
0.5 * Matern52Kernel() +
1.0 * ConstantKernel(),
skip_grad=true,
to_vec_grad=nothing,
),
)

Expand All @@ -126,14 +168,14 @@ println("lti_sde:")
(name="Custom Mean", val=CustomMean(x -> 2x)),
)

@testset "$(kernel.name), $(m.name), $(storage.name), $(t.name), $(σ².name)" for
kernel in kernels,
@testset "$(kernel.name), $(m.name), $(storage.name), $(t.name), $(σ².name)" for kernel in
kernels,
m in means,
storage in storages,
t in ts,
σ² in σ²s

println("$(kernel.name), $(storage.name), $(t.name), $(σ².name)")
println("$(kernel.name), $(storage.name), $(m.name), $(t.name), $(σ².name)")

# Construct Gauss-Markov model.
f_naive = GP(m.val, kernel.val)
Expand Down Expand Up @@ -174,7 +216,21 @@ println("lti_sde:")
end

# Just need to ensure we can differentiate through construction properly.
if !(hasfield(typeof(kernel), :skip_grad) && kernel.skip_grad)
if isnothing(kernel.to_vec_grad)
@test_broken "Gradient tests are not passing"
continue
elseif kernel.to_vec_grad
test_zygote_grad_finite_differences_compatible(
_construction_tester,
f_naive,
storage.val,
σ².val,
t.val;
check_inferred=false,
rtol=1e-6,
atol=1e-6,
)
else
test_zygote_grad(
_construction_tester,
f_naive,
Expand Down
10 changes: 9 additions & 1 deletion test/test_util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ test_zygote_grad(f, args...; check_inferred=false, kwargs...) = test_rrule(Zygot
function test_zygote_grad_finite_differences_compatible(f, args...; kwargs...)
x_vec, from_vec = to_vec(args)
function finite_diff_compatible_f(x::AbstractVector)
return @ignore_derivatives(f)(from_vec(x)...)
return @ignore_derivatives(f)(@ignore_derivatives(from_vec)(x)...)
end
test_zygote_grad(finite_diff_compatible_f NoTangent(), x_vec; testset_name="test_rrule: $(f) on $(typeof.(args))", kwargs...)
end
Expand Down Expand Up @@ -134,6 +134,14 @@ function ChainRulesTestUtils.test_approx(actual::Tangent{<:Fill}, expected, msg=
test_approx(actual.value, expected.value, msg; kwargs...)
end

function to_vec(x::PeriodicKernel)
x, to_r = to_vec(x.r)
function PeriodicKernel_from_vec(x)
return PeriodicKernel(;r=exp.(to_r(x)))
end
log.(x), PeriodicKernel_from_vec
end

to_vec(x::T) where {T} = generic_struct_to_vec(x)

# This is a copy from FiniteDifferences.jl without the try catch
Expand Down

4 comments on commit a931ac0

@theogf
Copy link
Member Author

@theogf theogf commented on a931ac0 May 9, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/83185

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.6.5 -m "<description of version>" a931ac04a9a69b60ac39a323da367bf43f80c489
git push origin v0.6.5

@simsurace
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/91700

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.6.5 -m "<description of version>" a931ac04a9a69b60ac39a323da367bf43f80c489
git push origin v0.6.5

Please sign in to comment.