From a169fcc7e33a788504a1b521c9db46f8a3bbb1dd Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Sat, 3 Aug 2024 11:09:18 +0100 Subject: [PATCH 01/21] Remove literal_getfield usage --- src/gp/lti_sde.jl | 57 +++++++++------------- src/gp/posterior_lti_sde.jl | 19 +++----- src/models/gauss_markov_model.jl | 2 +- src/models/lgssm.jl | 29 ++++------- src/models/linear_gaussian_conditionals.jl | 47 +++++------------- src/models/missings.jl | 2 +- src/space_time/rectilinear_grid.jl | 4 +- src/space_time/regular_in_time.jl | 4 +- src/util/gaussian.jl | 4 +- test/runtests.jl | 2 +- 10 files changed, 57 insertions(+), 113 deletions(-) diff --git a/src/gp/lti_sde.jl b/src/gp/lti_sde.jl index 0cf06737..03d7f7ed 100644 --- a/src/gp/lti_sde.jl +++ b/src/gp/lti_sde.jl @@ -71,25 +71,19 @@ end function build_lgssm(f::LTISDE, x::AbstractVector, Σys::AbstractVector) m = get_mean(f) k = get_kernel(f) - s = Zygote.literal_getfield(f, Val(:storage)) - As, as, Qs, emission_proj, x0 = lgssm_components(m, k, x, s) + As, as, Qs, emission_proj, x0 = lgssm_components(m, k, x, f.storage) return LGSSM( GaussMarkovModel(Forward(), As, as, Qs, x0), build_emissions(emission_proj, Σys), ) end -function build_lgssm(ft::FiniteLTISDE) - f = Zygote.literal_getfield(ft, Val(:f)) - x = Zygote.literal_getfield(ft, Val(:x)) - Σys = noise_var_to_time_form(x, Zygote.literal_getfield(ft, Val(:Σy))) - return build_lgssm(f, x, Σys) -end +build_lgssm(ft::FiniteLTISDE) = build_lgssm(ft.f, ft.x, noise_var_to_time_form(ft.x, ft.Σy)) -get_mean(f::LTISDE) = get_mean(Zygote.literal_getfield(f, Val(:f))) -get_mean(f::GP) = Zygote.literal_getfield(f, Val(:mean)) +get_mean(f::LTISDE) = get_mean(f.f) +get_mean(f::GP) = f.mean -get_kernel(f::LTISDE) = get_kernel(Zygote.literal_getfield(f, Val(:f))) -get_kernel(f::GP) = Zygote.literal_getfield(f, Val(:kernel)) +get_kernel(f::LTISDE) = get_kernel(f.f) +get_kernel(f::GP) = f.kernel function build_emissions( (Hs, hs)::Tuple{AbstractVector, AbstractVector}, Σs::AbstractVector, @@ -332,20 +326,18 @@ end # Scaled function to_sde(k::ScaledKernel, storage::StorageType{T}) where {T<:Real} - _k = Zygote.literal_getfield(k, Val(:kernel)) - σ² = Zygote.literal_getfield(k, Val(:σ²)) - F, q, H = to_sde(_k, storage) - σ = sqrt(convert(eltype(storage), only(σ²))) + F, q, H = to_sde(k.kernel, storage) + σ = sqrt(convert(eltype(storage), only(k.σ²))) return F, σ^2 * q, σ * H end -stationary_distribution(k::ScaledKernel, storage::StorageType) = stationary_distribution(Zygote.literal_getfield(k, Val(:kernel)), storage) +function stationary_distribution(k::ScaledKernel, storage::StorageType) + return stationary_distribution(k.kernel, storage) +end function lgssm_components(k::ScaledKernel, ts::AbstractVector, storage_type::StorageType) - _k = Zygote.literal_getfield(k, Val(:kernel)) - σ² = Zygote.literal_getfield(k, Val(:σ²)) - As, as, Qs, emission_proj, x0 = lgssm_components(_k, ts, storage_type) - σ = sqrt(convert(eltype(storage_type), only(σ²))) + As, as, Qs, emission_proj, x0 = lgssm_components(k.kernel, ts, storage_type) + σ = sqrt(convert(eltype(storage_type), only(k.σ²))) return As, as, Qs, _scale_emission_projections(emission_proj, σ), x0 end @@ -360,34 +352,29 @@ end # Stretched function to_sde(k::TransformedKernel{<:Kernel, <:ScaleTransform}, storage::StorageType) - _k = Zygote.literal_getfield(k, Val(:kernel)) - s = Zygote.literal_getfield(Zygote.literal_getfield(k, Val(:transform)), Val(:s)) - F, q, H = to_sde(_k, storage) - return F * only(s), q, H + F, q, H = to_sde(k.kernel, storage) + return F * only(k.transform.s), q, H end -stationary_distribution(k::TransformedKernel{<:Kernel, <:ScaleTransform}, storage::StorageType) = stationary_distribution(Zygote.literal_getfield(k, Val(:kernel)), storage) +function stationary_distribution( + k::TransformedKernel{<:Kernel, <:ScaleTransform}, storage::StorageType +) + return stationary_distribution(k.kernel, storage) +end function lgssm_components( k::TransformedKernel{<:Kernel, <:ScaleTransform}, ts::AbstractVector, storage_type::StorageType, ) - _k = Zygote.literal_getfield(k, Val(:kernel)) - s = Zygote.literal_getfield(Zygote.literal_getfield(k, Val(:transform)), Val(:s)) - return lgssm_components(_k, apply_stretch(s[1], ts), storage_type) + return lgssm_components(k.kernel, apply_stretch(k.transform.s[1], ts), storage_type) end apply_stretch(a, ts::AbstractVector{<:Real}) = a * ts apply_stretch(a, ts::StepRangeLen) = a * ts -function apply_stretch(a, ts::RegularSpacing) - t0 = Zygote.literal_getfield(ts, Val(:t0)) - Δt = Zygote.literal_getfield(ts, Val(:Δt)) - N = Zygote.literal_getfield(ts, Val(:N)) - return RegularSpacing(a * t0, a * Δt, N) -end +apply_stretch(a, ts::RegularSpacing) = RegularSpacing(a * ts.t0, a * ts.Δt, ts.N) # Product diff --git a/src/gp/posterior_lti_sde.jl b/src/gp/posterior_lti_sde.jl index 007fa389..94aad0ea 100644 --- a/src/gp/posterior_lti_sde.jl +++ b/src/gp/posterior_lti_sde.jl @@ -25,19 +25,12 @@ function AbstractGPs.marginals(fx::FinitePosteriorLTISDE) model_post = replace_observation_noise_cov(posterior(model, ys), σ²s_pr_full) return destructure(x, map(marginals, marginals(model_post))[pr_indices]) else - f = Zygote.literal_getfield(fx, Val(:f)) - prior = Zygote.literal_getfield(f, Val(:prior)) - x = Zygote.literal_getfield(fx, Val(:x)) - data = Zygote.literal_getfield(f, Val(:data)) - Σy = Zygote.literal_getfield(data, Val(:Σy)) - Σy_diag = Zygote.literal_getfield(Σy, Val(:diag)) - y = Zygote.literal_getfield(data, Val(:y)) - - Σy_new = Zygote.literal_getfield(fx, Val(:Σy)) - - model = build_lgssm(AbstractGPs.FiniteGP(prior, x, Σy)) - Σys_new = noise_var_to_time_form(x, Σy_new) - ys = observations_to_time_form(x, y) + f = fx.f + x = fx.x + data = f.data + model = build_lgssm(AbstractGPs.FiniteGP(f.prior, x, data.Σy)) + Σys_new = noise_var_to_time_form(x, fx.Σy) + ys = observations_to_time_form(x, data.y) model_post = replace_observation_noise_cov(posterior(model, ys), Σys_new) return destructure(x, map(marginals, marginals(model_post))) end diff --git a/src/models/gauss_markov_model.jl b/src/models/gauss_markov_model.jl index 7b57c262..4c2d62cf 100644 --- a/src/models/gauss_markov_model.jl +++ b/src/models/gauss_markov_model.jl @@ -65,7 +65,7 @@ function is_of_storage_type(model::GaussMarkovModel, s::StorageType) return is_of_storage_type((model.As, model.as, model.Qs, model.x0), s) end -x0(model::GaussMarkovModel) = Zygote.literal_getfield(model, Val(:x0)) +x0(model::GaussMarkovModel) = model.x0 function get_adjoint_storage(x::GaussMarkovModel, n::Int, Δx::Tangent{T,<:NamedTuple{(:A, :a, :Q)}}) where {T} return ( diff --git a/src/models/lgssm.jl b/src/models/lgssm.jl index e27fe9fc..9c211b77 100644 --- a/src/models/lgssm.jl +++ b/src/models/lgssm.jl @@ -11,13 +11,9 @@ struct LGSSM{Ttransitions<:GaussMarkovModel, Temissions<:StructArray} <: Abstrac emissions::Temissions end -@inline function transitions(model::LGSSM) - return Zygote.literal_getfield(model, Val(:transitions)) -end +@inline transitions(model::LGSSM) = model.transitions -@inline function emissions(model::LGSSM) - return Zygote.literal_getfield(model, Val(:emissions)) -end +@inline emissions(model::LGSSM) = model.emissions @inline ordering(model::LGSSM) = ordering(transitions(model)) ChainRulesCore.@non_differentiable ordering(model) @@ -58,17 +54,11 @@ struct ElementOfLGSSM{Tordering, Ttransition, Temission} emission::Temission end -@inline function ordering(x::ElementOfLGSSM) - return Zygote.literal_getfield(x, Val(:ordering)) -end +@inline ordering(x::ElementOfLGSSM) = x.ordering -@inline function transition_dynamics(x::ElementOfLGSSM) - return Zygote.literal_getfield(x, Val(:transition)) -end +@inline transition_dynamics(x::ElementOfLGSSM) = x.transition -@inline function emission_dynamics(x::ElementOfLGSSM) - return Zygote.literal_getfield(x, Val(:emission)) -end +@inline emission_dynamics(x::ElementOfLGSSM) = x.emission @inline function Base.getindex(model::LGSSM, n::Int) return ElementOfLGSSM(ordering(model), model.transitions[n], model.emissions[n]) @@ -206,11 +196,10 @@ end function posterior(prior::LGSSM, y::AbstractVector) _check_inputs(prior, y) new_trans, xf = _a_bit_of_posterior(prior, y) - A = zygote_friendly_map(x -> Zygote.literal_getfield(x, Val(:A)), new_trans) - a = zygote_friendly_map(x -> Zygote.literal_getfield(x, Val(:a)), new_trans) - Q = zygote_friendly_map(x -> Zygote.literal_getfield(x, Val(:Q)), new_trans) - ems = Zygote.literal_getfield(prior, Val(:emissions)) - return LGSSM(GaussMarkovModel(reverse(ordering(prior)), A, a, Q, xf), ems) + A = zygote_friendly_map(x -> x.A, new_trans) + a = zygote_friendly_map(x -> x.a, new_trans) + Q = zygote_friendly_map(x -> x.Q, new_trans) + return LGSSM(GaussMarkovModel(reverse(ordering(prior)), A, a, Q, xf), prior.emissions) end function _check_inputs(prior, y) diff --git a/src/models/linear_gaussian_conditionals.jl b/src/models/linear_gaussian_conditionals.jl index 0b5fe798..62f65f69 100644 --- a/src/models/linear_gaussian_conditionals.jl +++ b/src/models/linear_gaussian_conditionals.jl @@ -126,14 +126,9 @@ dim_out(f::SmallOutputLGC) = size(f.A, 1) dim_in(f::SmallOutputLGC) = size(f.A, 2) -noise_cov(f::SmallOutputLGC) = Zygote.literal_getfield(f, Val(:Q)) +noise_cov(f::SmallOutputLGC) = f.Q -function get_fields(f::SmallOutputLGC) - A = Zygote.literal_getfield(f, Val(:A)) - a = Zygote.literal_getfield(f, Val(:a)) - Q = Zygote.literal_getfield(f, Val(:Q)) - return A, a, Q -end +get_fields(f::SmallOutputLGC) = (f.A, f.a, f.Q) function posterior_and_lml(x::Gaussian, f::SmallOutputLGC, y::AbstractVector{<:Real}) m, P = get_fields(x) @@ -191,14 +186,9 @@ dim_out(f::LargeOutputLGC) = size(f.A, 1) dim_in(f::LargeOutputLGC) = size(f.A, 2) -noise_cov(f::LargeOutputLGC) = Zygote.literal_getfield(f, Val(:Q)) +noise_cov(f::LargeOutputLGC) = f.Q -function get_fields(f::LargeOutputLGC) - A = Zygote.literal_getfield(f, Val(:A)) - a = Zygote.literal_getfield(f, Val(:a)) - Q = Zygote.literal_getfield(f, Val(:Q)) - return A, a, Q -end +get_fields(f::LargeOutputLGC) = (f.A, f.a, f.Q) function posterior_and_lml(x::Gaussian, f::LargeOutputLGC, y::AbstractVector{<:Real}) m, _P = get_fields(x) @@ -258,18 +248,12 @@ dim_out(f::ScalarOutputLGC) = 1 dim_in(f::ScalarOutputLGC) = size(f.A, 2) -function get_fields(f::ScalarOutputLGC) - A = Zygote.literal_getfield(f, Val(:A)) - a = Zygote.literal_getfield(f, Val(:a)) - Q = Zygote.literal_getfield(f, Val(:Q)) - return A, a, Q -end +get_fields(f::ScalarOutputLGC) = (f.A, f.a, f.Q) -noise_cov(f::ScalarOutputLGC) = Zygote.literal_getfield(f, Val(:Q)) +noise_cov(f::ScalarOutputLGC) = f.Q function conditional_rand(ε::Real, f::ScalarOutputLGC, x::AbstractVector) - A, a, Q = get_fields(f) - return (A * x + a) + sqrt(Q) * ε + return (f.A * x + f.a) + sqrt(f.Q) * ε end ε_randn(rng::AbstractRNG, f::ScalarOutputLGC) = randn(rng, eltype(f)) @@ -323,16 +307,9 @@ dim_out(f::BottleneckLGC) = dim_out(f.fan_out) dim_in(f::BottleneckLGC) = size(f.H, 2) -noise_cov(f::BottleneckLGC) = noise_cov(Zygote.literal_getfield(f, Val(:fan_out))) +noise_cov(f::BottleneckLGC) = noise_cov(f.fan_out) -function get_fields(f::BottleneckLGC) - H = Zygote.literal_getfield(f, Val(:H)) - h = Zygote.literal_getfield(f, Val(:h)) - fan_out = Zygote.literal_getfield(f, Val(:fan_out)) - return H, h, fan_out -end - -fan_out(f::BottleneckLGC) = Zygote.literal_getfield(f, Val(:fan_out)) +get_fields(f::BottleneckLGC) = (f.H, f.h, f.fan_out) function conditional_rand(ε::AbstractVector{<:Real}, f::BottleneckLGC, x::AbstractVector) H, h, fan_out = get_fields(f) @@ -348,12 +325,10 @@ function _project(x::Gaussian, f::BottleneckLGC) return Gaussian(H * m + h, H * P * H' + ident_eps(x)) end -function predict(x::Gaussian, f::BottleneckLGC) - return predict(_project(x, f), fan_out(f)) -end +predict(x::Gaussian, f::BottleneckLGC) = predict(_project(x, f), f.fan_out) function predict_marginals(x::Gaussian, f::BottleneckLGC) - return predict_marginals(_project(x, f), fan_out(f)) + return predict_marginals(_project(x, f), f.fan_out) end function posterior_and_lml(x::Gaussian, f::BottleneckLGC, y::AbstractVector) diff --git a/src/models/missings.jl b/src/models/missings.jl index fd2e2e92..18ffb584 100644 --- a/src/models/missings.jl +++ b/src/models/missings.jl @@ -79,7 +79,7 @@ function _fill_in_missings(Σs::Vector, y::AbstractVector{Union{Missing, T}}) wh end function fill_in_missings(Σ::Diagonal, y::AbstractVector{<:Union{Missing, <:Real}}) - Σ_diag_filled, y_filled = fill_in_missings(Zygote.literal_getfield(Σ, Val(:diag)), y) + Σ_diag_filled, y_filled = fill_in_missings(Σ.diag, y) return Diagonal(Σ_diag_filled), y_filled end diff --git a/src/space_time/rectilinear_grid.jl b/src/space_time/rectilinear_grid.jl index cc7558f8..786bb078 100644 --- a/src/space_time/rectilinear_grid.jl +++ b/src/space_time/rectilinear_grid.jl @@ -15,9 +15,9 @@ struct RectilinearGrid{ xr::Txr end -get_space(x::RectilinearGrid) = Zygote.literal_getfield(x, Val(:xl)) +get_space(x::RectilinearGrid) = x.xl -get_times(x::RectilinearGrid) = Zygote.literal_getfield(x, Val(:xr)) +get_times(x::RectilinearGrid) = x.xr Base.size(X::RectilinearGrid) = (length(X.xl) * length(X.xr),) diff --git a/src/space_time/regular_in_time.jl b/src/space_time/regular_in_time.jl index c3abac1b..3c823118 100644 --- a/src/space_time/regular_in_time.jl +++ b/src/space_time/regular_in_time.jl @@ -12,9 +12,9 @@ struct RegularInTime{ vs::Tvs end -get_space(x::RegularInTime) = Zygote.literal_getfield(x, Val(:vs)) +get_space(x::RegularInTime) = x.vs -get_times(x::RegularInTime) = Zygote.literal_getfield(x, Val(:ts)) +get_times(x::RegularInTime) = x.ts Base.size(x::RegularInTime) = (sum(length, x.vs), ) diff --git a/src/util/gaussian.jl b/src/util/gaussian.jl index 25c531fe..6f5bf00c 100644 --- a/src/util/gaussian.jl +++ b/src/util/gaussian.jl @@ -20,9 +20,9 @@ end dim(x::Gaussian) = length(x.m) -AbstractGPs.mean(x::Gaussian) = Zygote.literal_getfield(x, Val(:m)) +AbstractGPs.mean(x::Gaussian) = x.m -AbstractGPs.cov(x::Gaussian) = Zygote.literal_getfield(x, Val(:P)) +AbstractGPs.cov(x::Gaussian) = x.P AbstractGPs.var(x::Gaussian{<:AbstractVector}) = diag(cov(x)) diff --git a/test/runtests.jl b/test/runtests.jl index 0620b706..ce79abe6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,7 +7,7 @@ ENV["TESTING"] = "TRUE" # Select any of this to test a particular aspect. # To test everything, simply set GROUP to "all" # ENV["GROUP"] = "test gp" -const GROUP = get(ENV, "GROUP", "test") +const GROUP = get(ENV, "GROUP", "all") OUTER_GROUP = first(split(GROUP, ' ')) const TEST_TYPE_INFER = false # Test type stability over the tests From 6d807e68831fc59c9a9d6cf82e165fd810d67c74 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Thu, 22 Aug 2024 08:47:50 +0100 Subject: [PATCH 02/21] Improve perf --- src/space_time/pseudo_point.jl | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/space_time/pseudo_point.jl b/src/space_time/pseudo_point.jl index bcb90d5a..c5867df0 100644 --- a/src/space_time/pseudo_point.jl +++ b/src/space_time/pseudo_point.jl @@ -73,6 +73,7 @@ function AbstractGPs.elbo(fx::FiniteLTISDE, y::AbstractVector, z_r::AbstractVect k = fx_dtc.f.f.kernel Cf_diags = kernel_diagonals(k, fx_dtc.x) + # return Cf_diags # Transform a vector into a vector-of-vectors. y_vecs = restructure(y, lgssm.emissions) @@ -85,6 +86,7 @@ function AbstractGPs.elbo(fx::FiniteLTISDE, y::AbstractVector, z_r::AbstractVect end, zip(Σs, Cf_diags, marg_diags, y_vecs), ) + # return -sum(tmp) / 2 return logpdf(lgssm, y_vecs) - sum(tmp) / 2 end @@ -101,14 +103,8 @@ end function kernel_diagonals(k::DTCSeparable, x::RegularInTime) space_kernel = k.k.l - time_kernel = k.k.r - time_vars = kernelmatrix_diag(time_kernel, get_times(x)) - return Diagonal.( - kernelmatrix_diag.( - Ref(space_kernel), - x.vs - ) .* time_vars - ) + time_vars = kernelmatrix_diag(k.k.r, get_times(x)) + return map((v, tv) -> Diagonal(kernelmatrix_diag(space_kernel, v) * tv), x.vs, time_vars) end function kernel_diagonals(k::ScaledKernel, x::AbstractVector) From da5540dbf68f06d7c5242570f2cb8643a0a61829 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Thu, 22 Aug 2024 08:48:36 +0100 Subject: [PATCH 03/21] Progress --- examples/Project.toml | 2 +- examples/approx_space_time_learning.jl | 73 +++++++++++++++++++++++--- examples/exact_space_time_learning.jl | 24 +++++++-- examples/exact_time_learning.jl | 11 ++-- 4 files changed, 92 insertions(+), 18 deletions(-) diff --git a/examples/Project.toml b/examples/Project.toml index 10ecbb61..1d673c0b 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -6,5 +6,5 @@ Optim = "429524aa-4258-5aef-a3af-852621145aeb" ParameterHandling = "2412ca09-6db7-441c-8e3a-88d5709968c5" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" +Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b" TemporalGPs = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/examples/approx_space_time_learning.jl b/examples/approx_space_time_learning.jl index ca007981..2748425c 100644 --- a/examples/approx_space_time_learning.jl +++ b/examples/approx_space_time_learning.jl @@ -13,7 +13,7 @@ using TemporalGPs: Separable, approx_posterior_marginals, RegularInTime # Load standard packages from the Julia ecosystem using Optim # Standard optimisation algorithms. using ParameterHandling # Helper functionality for dealing with model parameters. -using Zygote # Algorithmic Differentiation +using Tapir # Algorithmic Differentiation using ParameterHandling: flatten @@ -56,16 +56,72 @@ z_r = collect(range(-3.0, 3.0; length=5)); # Specify an objective function for Optim to minimise in terms of x and y. # We choose the usual negative log marginal likelihood (NLML). -function objective(params) - f = build_gp(params) - return -elbo(f(x, params.var_noise), y, z_r) +function make_objective(unpack, x, y, z_r) + function objective(flat_params) + params = unpack(flat_params) + f = build_gp(params) + return elbo(f(x, params.var_noise), y, z_r) + end + return objective end +objective = make_objective(unpack, x, y, z_r) -# Optimise using Optim. Takes a little while to compile because Zygote. +using Tapir: CoDual, primal + +Tapir.@is_primitive Tapir.MinimalCtx Tuple{typeof(TemporalGPs.time_exp), AbstractMatrix{<:Real}, Real} +function Tapir.rrule!!(::CoDual{typeof(TemporalGPs.time_exp)}, A::CoDual, t::CoDual{Float64}) + B_dB = Tapir.zero_fcodual(TemporalGPs.time_exp(primal(A), primal(t))) + B = primal(B_dB) + dB = tangent(B_dB) + time_exp_pb(::NoRData) = NoRData(), NoRData(), sum(dB .* (primal(A) * B)) + return B_dB, time_exp_pb +end + + + +using Random +# y = y +# z_r = z_r +# fx = build_gp(unpack(flat_initial_params))(x, params.var_noise) +# fx_dtc = TemporalGPs.dtcify(z_r, fx) +# lgssm = TemporalGPs.build_lgssm(fx_dtc) +# Σs = lgssm.emissions.fan_out.Q +# marg_diags = TemporalGPs.marginals_diag(lgssm) + +# k = fx_dtc.f.f.kernel +# Cf_diags = TemporalGPs.kernel_diagonals(k, fx_dtc.x) + +# # Transform a vector into a vector-of-vectors. +# y_vecs = TemporalGPs.restructure(y, lgssm.emissions) + +# tmp = TemporalGPs.zygote_friendly_map( +# ((Σ, Cf_diag, marg_diag, yn), ) -> begin +# Σ_, _ = TemporalGPs.fill_in_missings(Σ, yn) +# return sum(TemporalGPs.diag(Σ_ \ (Cf_diag - marg_diag.P))) - +# count(ismissing, yn) + size(Σ_, 1) +# end, +# zip(Σs, Cf_diags, marg_diags, y_vecs), +# ) + +# logpdf(lgssm, y_vecs) # this is the failing thing + +for _ in 1:10 + Tapir.TestUtils.test_rule( + Xoshiro(123456), objective, flat_initial_params; + perf_flag=:none, + interp=Tapir.TapirInterpreter(), + interface_only=false, + is_primitive=false, + safety_on=false, + ) +end + +# Optimise using Optim. +rule = Tapir.build_rrule(objective, flat_initial_params); training_results = Optim.optimize( - objective ∘ unpack, - θ -> only(Zygote.gradient(objective ∘ unpack, θ)), - flat_initial_params, + objective, + θ -> Tapir.value_and_gradient!!(rule, objective, θ)[2][2], + flat_initial_params + randn(4), # Add some noise to make learning non-trivial BFGS( alphaguess = Optim.LineSearches.InitialStatic(scaled=true), linesearch = Optim.LineSearches.BackTracking(), @@ -74,6 +130,7 @@ training_results = Optim.optimize( inplace=false, ); + # Extracting the final values of the parameters. # Should be close to truth. final_params = unpack(training_results.minimizer); diff --git a/examples/exact_space_time_learning.jl b/examples/exact_space_time_learning.jl index f9664ac3..c2ac99ca 100644 --- a/examples/exact_space_time_learning.jl +++ b/examples/exact_space_time_learning.jl @@ -13,7 +13,7 @@ using TemporalGPs: Separable, RectilinearGrid # Load standard packages from the Julia ecosystem using Optim # Standard optimisation algorithms. using ParameterHandling # Helper functionality for dealing with model parameters. -using Zygote # Algorithmic Differentiation +using Tapir # Algorithmic Differentiation # Declare model parameters using `ParameterHandling.jl` types. flat_initial_params, unflatten = ParameterHandling.flatten(( @@ -47,15 +47,29 @@ y = rand(build_gp(params)(x, 1e-4)); # Specify an objective function for Optim to minimise in terms of x and y. # We choose the usual negative log marginal likelihood (NLML). -function objective(params) +function objective(flat_params) + params = unpack(flat_params) f = build_gp(params) return -logpdf(f(x, params.var_noise), y) end -# Optimise using Optim. Takes a little while to compile because Zygote. +using Tapir: CoDual, primal + +Tapir.@is_primitive Tapir.MinimalCtx Tuple{typeof(TemporalGPs.time_exp), AbstractMatrix{<:Real}, Real} +function Tapir.rrule!!(::CoDual{typeof(TemporalGPs.time_exp)}, A::CoDual, t::CoDual{Float64}) + B_dB = Tapir.zero_fcodual(TemporalGPs.time_exp(primal(A), primal(t))) + B = primal(B_dB) + dB = tangent(B_dB) + time_exp_pb(::NoRData) = NoRData(), NoRData(), sum(dB .* (primal(A) * B)) + return B_dB, time_exp_pb +end + +rule = Tapir.build_rrule(objective, flat_initial_params); + +# Optimise using Optim. training_results = Optim.optimize( - objective ∘ unpack, - θ -> only(Zygote.gradient(objective ∘ unpack, θ)), + objective, + θ -> Tapir.value_and_gradient!!(rule, objective, θ)[2][2], flat_initial_params + randn(4), # Add some noise to make learning non-trivial BFGS( alphaguess = Optim.LineSearches.InitialStatic(scaled=true), diff --git a/examples/exact_time_learning.jl b/examples/exact_time_learning.jl index b4b05cb4..46a24203 100644 --- a/examples/exact_time_learning.jl +++ b/examples/exact_time_learning.jl @@ -12,7 +12,7 @@ using TemporalGPs: RegularSpacing # Load standard packages from the Julia ecosystem using Optim # Standard optimisation algorithms. using ParameterHandling # Helper functionality for dealing with model parameters. -using Zygote # Algorithmic Differentiation +using Tapir # Algorithmic Differentiation # Declare model parameters using `ParameterHandling.jl` types. # var_kernel is the variance of the kernel, λ the inverse length scale, and var_noise the @@ -42,15 +42,18 @@ y = rand(f(x, params.var_noise)); # Specify an objective function for Optim to minimise in terms of x and y. # We choose the usual negative log marginal likelihood (NLML). -function objective(params) +function objective(flat_params) + params = unpack(flat_params) f = build_gp(params) return -logpdf(f(x, params.var_noise), y) end +rule = Tapir.build_rrule(objective, flat_initial_params); + # Optimise using Optim. Zygote takes a little while to compile. training_results = Optim.optimize( - objective ∘ unpack, - θ -> only(Zygote.gradient(objective ∘ unpack, θ)), + objective, + θ -> Tapir.value_and_gradient!!(rule, objective, θ)[2][2], flat_initial_params .+ randn.(), # Perturb the parameters to make learning non-trivial BFGS( alphaguess = Optim.LineSearches.InitialStatic(scaled=true), From 56337d6a2df9f433d782ba93d98fb006c6113dd5 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Thu, 26 Sep 2024 23:23:19 +0100 Subject: [PATCH 04/21] Lots of changes --- .github/workflows/ci.yml | 2 - Project.toml | 24 +- README.md | 10 +- bench/Manifest.toml | 477 ---------------- bench/Project.toml | 1 - bench/lgssm.jl | 4 +- bench/predict.jl | 20 +- bench/single_output_gps.jl | 6 +- examples/Project.toml | 2 +- examples/approx_space_time_learning.jl | 123 ++-- examples/exact_space_time_learning.jl | 21 +- examples/exact_time_inference.jl | 2 +- examples/exact_time_learning.jl | 12 +- ext/TemporalGPsMooncakeExt.jl | 15 + src/TemporalGPs.jl | 5 - src/gp/lti_sde.jl | 71 +-- src/models/gauss_markov_model.jl | 32 -- src/models/lgssm.jl | 29 - src/models/linear_gaussian_conditionals.jl | 14 - src/models/missings.jl | 37 -- src/space_time/pseudo_point.jl | 54 +- src/space_time/rectilinear_grid.jl | 7 +- src/space_time/regular_in_time.jl | 11 +- src/space_time/to_gauss_markov.jl | 14 +- src/util/chainrules.jl | 409 -------------- src/util/gaussian.jl | 7 - src/util/harmonise.jl | 122 ---- src/util/regular_data.jl | 15 +- src/util/scan.jl | 137 ----- src/util/zygote_friendly_map.jl | 39 -- test/Project.toml | 29 - test/front_matter.jl | 39 ++ test/gp/lti_sde.jl | 111 ++-- test/models/lgssm.jl | 36 +- test/models/linear_gaussian_conditionals.jl | 6 +- test/models/missings.jl | 27 +- test/models/model_test_utils.jl | 39 -- test/runtests.jl | 123 ++-- test/space_time/pseudo_point.jl | 3 +- test/space_time/rectilinear_grid.jl | 12 - test/space_time/regular_in_time.jl | 2 - test/space_time/separable_kernel.jl | 3 - test/space_time/to_gauss_markov.jl | 44 +- test/test_util.jl | 585 +------------------- test/util/chainrules.jl | 117 ---- test/util/gaussian.jl | 10 - test/util/harmonise.jl | 57 -- test/util/mul.jl | 3 - test/util/regular_data.jl | 20 - test/util/scan.jl | 13 - test/util/zygote_friendly_map.jl | 4 - 51 files changed, 342 insertions(+), 2663 deletions(-) delete mode 100644 bench/Manifest.toml create mode 100644 ext/TemporalGPsMooncakeExt.jl delete mode 100644 src/util/chainrules.jl delete mode 100644 src/util/harmonise.jl delete mode 100644 test/Project.toml create mode 100644 test/front_matter.jl delete mode 100644 test/util/chainrules.jl delete mode 100644 test/util/harmonise.jl diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ada7a038..5b467ae5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -20,7 +20,6 @@ jobs: matrix: version: - '1' - - '1.6' os: - ubuntu-latest arch: @@ -28,7 +27,6 @@ jobs: group: - 'test util' - 'test models' - - 'test models-lgssm' - 'test gp' - 'test space_time' steps: diff --git a/Project.toml b/Project.toml index 4e0e7df3..1334e9a9 100644 --- a/Project.toml +++ b/Project.toml @@ -1,29 +1,41 @@ name = "TemporalGPs" uuid = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f" -authors = ["willtebbutt and contributors"] -version = "0.6.8" +authors = ["Will Tebbutt and contributors"] +version = "0.7.0" [deps] AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918" Bessels = "0e736298-9ec6-45e8-9647-e4fc86a2fe38" BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0" -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[weakdeps] +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" + +[extensions] +TemporalGPsMooncakeExt = "Mooncake" [compat] AbstractGPs = "0.5.17" +BenchmarkTools = "1" Bessels = "0.2.8" BlockDiagonals = "0.1.7" -ChainRulesCore = "1" FillArrays = "0.13.0 - 0.13.7, 1" KernelFunctions = "0.9, 0.10.1" +Mooncake = "0.4" StaticArrays = "1" StructArrays = "0.5, 0.6" -Zygote = "0.6.65" julia = "1.6" + +[extras] +BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" + +[targets] +test = ["BenchmarkTools", "Mooncake"] diff --git a/README.md b/README.md index c566924a..c4b88ed1 100644 --- a/README.md +++ b/README.md @@ -41,7 +41,7 @@ f = to_sde(f_naive, SArrayStorage(Float64)) # Project onto finite-dimensional distribution as usual. # x = range(-5.0; step=0.1, length=10_000) -x = RegularSpacing(0.0, 0.1, 10_000) # Hack for Zygote. +x = RegularSpacing(0.0, 0.1, 10_000) # Hack for AD. fx = f(x, 0.1) # Sample from the prior as usual. @@ -63,7 +63,7 @@ rand(f_post(x)) logpdf(f_post(x), y) ``` -## Learning kernel parameters with [Optim.jl](https://github.com/JuliaNLSolvers/Optim.jl), [ParameterHandling.jl](https://github.com/invenia/ParameterHandling.jl), and [Zygote.jl](https://github.com/FluxML/Zygote.jl/) +## Learning kernel parameters with [Optim.jl](https://github.com/JuliaNLSolvers/Optim.jl), [ParameterHandling.jl](https://github.com/invenia/ParameterHandling.jl), and [Mooncake.jl](https://github.com/compintell/Mooncake.jl/) TemporalGPs.jl doesn't provide scikit-learn-like functionality to train your model (find good kernel parameter settings). Instead, we offer the functionality needed to easily implement your own training functionality using standard tools from the Julia ecosystem, as shown below. @@ -76,7 +76,7 @@ using TemporalGPs # Load standard packages from the Julia ecosystem using Optim # Standard optimisation algorithms. using ParameterHandling # Helper functionality for dealing with model parameters. -using Zygote # Algorithmic Differentiation +using Mooncake # Algorithmic Differentiation using ParameterHandling: flatten @@ -115,7 +115,7 @@ objective(params) # Optim.jl for more info on available optimisers and their properties. training_results = Optim.optimize( objective ∘ unpack, - θ -> only(Zygote.gradient(objective ∘ unpack, θ)), + θ -> only(Mooncake.gradient(objective ∘ unpack, θ)), flat_initial_params + randn(3), # Add some noise to make learning non-trivial BFGS( alphaguess = Optim.LineSearches.InitialStatic(scaled=true), @@ -152,7 +152,7 @@ This tells TemporalGPs that you want all parameters of `f` and anything derived "naive" timings are with the usual [AbstractGPs.jl](https://https://github.com/JuliaGaussianProcesses/AbstractGPs.jl/) inference routines, and is the default implementation for GPs. "lgssm" timings are conducted using `to_sde` with no additional arguments. "static-lgssm" uses the `SArrayStorage(Float64)` option discussed above. -Gradient computations use Zygote. Custom adjoints have been implemented to achieve this level of performance. +Gradient computations use Mooncake. Custom adjoints have been implemented to achieve this level of performance. diff --git a/bench/Manifest.toml b/bench/Manifest.toml deleted file mode 100644 index 3bcae4a2..00000000 --- a/bench/Manifest.toml +++ /dev/null @@ -1,477 +0,0 @@ -# This file is machine-generated - editing it directly is not advised - -[[AbstractFFTs]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "051c95d6836228d120f5f4b984dd5aba1624f716" -uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" -version = "0.5.0" - -[[ArgCheck]] -git-tree-sha1 = "dedbbb2ddb876f899585c4ec4433265e3017215a" -uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197" -version = "2.1.0" - -[[ArrayLayouts]] -deps = ["FillArrays", "LinearAlgebra"] -git-tree-sha1 = "951c3fc1ff93497c88fb1dfa893f4de55d0b38e3" -uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" -version = "0.3.8" - -[[BSON]] -git-tree-sha1 = "dd36d7cf3d185eeaaf64db902c15174b22f5dafb" -uuid = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" -version = "0.2.6" - -[[Base64]] -uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" - -[[BenchmarkTools]] -deps = ["JSON", "Logging", "Printf", "Statistics", "UUIDs"] -git-tree-sha1 = "9e62e66db34540a0c919d72172cc2f642ac71260" -uuid = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" -version = "0.5.0" - -[[BinaryProvider]] -deps = ["Libdl", "Logging", "SHA"] -git-tree-sha1 = "ecdec412a9abc8db54c0efc5548c64dfce072058" -uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232" -version = "0.5.10" - -[[BlockArrays]] -deps = ["ArrayLayouts", "Compat", "LinearAlgebra"] -git-tree-sha1 = "aabca4dc05a4bb8ac5e638940aa40e1951af6d32" -uuid = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" -version = "0.12.11" - -[[BlockDiagonals]] -deps = ["FillArrays", "LinearAlgebra"] -git-tree-sha1 = "014018143ebbec43ac12d26f164eb3f049aa822f" -uuid = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0" -version = "0.1.6" - -[[CategoricalArrays]] -deps = ["Compat", "DataAPI", "Future", "JSON", "Missings", "Printf", "Reexport", "Statistics", "Unicode"] -git-tree-sha1 = "23d7324164c89638c18f6d7f90d972fa9c4fa9fb" -uuid = "324d7699-5711-5eae-9e2f-1d82baa6b597" -version = "0.7.7" - -[[ChainRules]] -deps = ["ChainRulesCore", "LinearAlgebra", "Reexport", "Requires", "Statistics"] -git-tree-sha1 = "76cd719cb7ab57bd2687dcb3b186c4f99820a79d" -uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.6.5" - -[[ChainRulesCore]] -deps = ["MuladdMacro"] -git-tree-sha1 = "c384e0e4fe6bfeb6bec0d41f71cc5e391cd110ba" -uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.8.1" - -[[CommonSubexpressions]] -deps = ["MacroTools", "Test"] -git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" -uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" -version = "0.3.0" - -[[Compat]] -deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] -git-tree-sha1 = "215f1c81cfd1c5416cd78740bff8ef59b24cd7c0" -uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "3.15.0" - -[[CompilerSupportLibraries_jll]] -deps = ["Libdl", "Pkg"] -git-tree-sha1 = "7c4f882c41faa72118841185afc58a2eb00ef612" -uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "0.3.3+0" - -[[DataAPI]] -git-tree-sha1 = "176e23402d80e7743fc26c19c681bfb11246af32" -uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" -version = "1.3.0" - -[[DataFrames]] -deps = ["CategoricalArrays", "Compat", "DataAPI", "Future", "InvertedIndices", "IteratorInterfaceExtensions", "Missings", "PooledArrays", "Printf", "REPL", "Reexport", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] -git-tree-sha1 = "7d5bf815cc0b30253e3486e8ce2b93bf9d0faff6" -uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" -version = "0.20.2" - -[[DataStructures]] -deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "65974157a18f4e19c07e5a94576328814ae23f9b" -uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.3" - -[[DataValueInterfaces]] -git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" -uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" -version = "1.0.0" - -[[Dates]] -deps = ["Printf"] -uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" - -[[DefaultApplication]] -deps = ["InteractiveUtils"] -git-tree-sha1 = "fc2b7122761b22c87fec8bf2ea4dc4563d9f8c24" -uuid = "3f0dd361-4fe0-5fc6-8523-80b14ec94d85" -version = "1.0.0" - -[[DelimitedFiles]] -deps = ["Mmap"] -uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" - -[[DiffResults]] -deps = ["StaticArrays"] -git-tree-sha1 = "da24935df8e0c6cf28de340b958f6aac88eaa0cc" -uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" -version = "1.0.2" - -[[DiffRules]] -deps = ["NaNMath", "Random", "SpecialFunctions"] -git-tree-sha1 = "eb0c34204c8410888844ada5359ac8b96292cfd1" -uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" -version = "1.0.1" - -[[Distances]] -deps = ["LinearAlgebra", "Statistics"] -git-tree-sha1 = "bed62cc5afcff16de797a9f38fb358b74071f785" -uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" -version = "0.9.0" - -[[Distributed]] -deps = ["Random", "Serialization", "Sockets"] -uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" - -[[Distributions]] -deps = ["FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns"] -git-tree-sha1 = "9c41285c57c6e0d73a21ed4b65f6eec34805f937" -uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" -version = "0.23.8" - -[[DocStringExtensions]] -deps = ["LibGit2", "Markdown", "Pkg", "Test"] -git-tree-sha1 = "50ddf44c53698f5e784bbebb3f4b21c5807401b1" -uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -version = "0.8.3" - -[[Documenter]] -deps = ["Base64", "DocStringExtensions", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"] -git-tree-sha1 = "d45c163c7a3ae293c15361acc52882c0f853f97c" -uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -version = "0.23.4" - -[[DrWatson]] -deps = ["Dates", "FileIO", "LibGit2", "Pkg", "Random", "Requires", "UnPack"] -git-tree-sha1 = "49e69db4a37a611f71f466f55d2bf516af42217f" -uuid = "634d3b9d-ee7a-5ddf-bec9-22491ea816e1" -version = "1.15.1" - -[[FileIO]] -deps = ["Pkg"] -git-tree-sha1 = "992b4aeb62f99b69fcf0cb2085094494cc05dfb3" -uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" -version = "1.4.3" - -[[FillArrays]] -deps = ["LinearAlgebra", "Random", "SparseArrays"] -git-tree-sha1 = "4863cbb7910079369e258dee4add9d06ead5063a" -uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "0.8.14" - -[[ForwardDiff]] -deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "NaNMath", "Random", "SpecialFunctions", "StaticArrays"] -git-tree-sha1 = "1d090099fb82223abc48f7ce176d3f7696ede36d" -uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.12" - -[[Future]] -deps = ["Random"] -uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" - -[[IRTools]] -deps = ["InteractiveUtils", "MacroTools", "Test"] -git-tree-sha1 = "a8d88c05a23b44b4da6cf4fb5659e13ff95e0f47" -uuid = "7869d1d1-7146-5819-86e3-90919afe41df" -version = "0.4.1" - -[[InteractiveUtils]] -deps = ["Markdown"] -uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" - -[[InvertedIndices]] -deps = ["Test"] -git-tree-sha1 = "15732c475062348b0165684ffe28e85ea8396afc" -uuid = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" -version = "1.0.0" - -[[IteratorInterfaceExtensions]] -git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" -uuid = "82899510-4779-5014-852e-03e436cf321d" -version = "1.0.0" - -[[JSON]] -deps = ["Dates", "Mmap", "Parsers", "Unicode"] -git-tree-sha1 = "b34d7cef7b337321e97d22242c3c2b91f476748e" -uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" -version = "0.21.0" - -[[Kronecker]] -deps = ["Documenter", "FillArrays", "LinearAlgebra", "Random", "SparseArrays", "StatsBase", "Test"] -git-tree-sha1 = "1c73eac80855eba3f67d223e0753ac3c97dde7cd" -uuid = "2c470bb0-bcc8-11e8-3dad-c9649493f05e" -version = "0.4.0" - -[[LibGit2]] -deps = ["Printf"] -uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" - -[[Libdl]] -uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" - -[[LinearAlgebra]] -deps = ["Libdl"] -uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" - -[[Logging]] -uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" - -[[MacroTools]] -deps = ["Markdown", "Random"] -git-tree-sha1 = "f7d2e3f654af75f01ec49be82c231c382214223a" -uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -version = "0.5.5" - -[[Markdown]] -deps = ["Base64"] -uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" - -[[Missings]] -deps = ["DataAPI"] -git-tree-sha1 = "ed61674a0864832495ffe0a7e889c0da76b0f4c8" -uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" -version = "0.4.4" - -[[Mmap]] -uuid = "a63ad114-7e13-5084-954f-fe012c677804" - -[[MuladdMacro]] -git-tree-sha1 = "c6190f9a7fc5d9d5915ab29f2134421b12d24a68" -uuid = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221" -version = "0.2.2" - -[[NNlib]] -deps = ["BinaryProvider", "Libdl", "LinearAlgebra", "Requires", "Statistics"] -git-tree-sha1 = "d9f196d911f55aeaff11b11f681b135980783824" -uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.6.6" - -[[NaNMath]] -git-tree-sha1 = "c84c576296d0e2fbb3fc134d3e09086b3ea617cd" -uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" -version = "0.3.4" - -[[OpenSpecFun_jll]] -deps = ["CompilerSupportLibraries_jll", "Libdl", "Pkg"] -git-tree-sha1 = "d51c416559217d974a1113522d5919235ae67a87" -uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" -version = "0.5.3+3" - -[[OrderedCollections]] -git-tree-sha1 = "293b70ac1780f9584c89268a6e2a560d938a7065" -uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.3.0" - -[[PDMats]] -deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse", "Test"] -git-tree-sha1 = "b3405086eb6a974eba1958923d46bc0e1c2d2d63" -uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" -version = "0.10.0" - -[[PGFPlotsX]] -deps = ["ArgCheck", "DataStructures", "Dates", "DefaultApplication", "DocStringExtensions", "MacroTools", "Parameters", "Requires", "Tables"] -git-tree-sha1 = "1adde3d07cce96b6a3bb88572612db4bd9d6153b" -uuid = "8314cec4-20b6-5062-9cdb-752b83310925" -version = "1.2.10" - -[[Parameters]] -deps = ["OrderedCollections", "UnPack"] -git-tree-sha1 = "38b2e970043613c187bd56a995fe2e551821eb4a" -uuid = "d96e819e-fc66-5662-9728-84c9c7592b0a" -version = "0.12.1" - -[[Parsers]] -deps = ["Dates", "Test"] -git-tree-sha1 = "8077624b3c450b15c087944363606a6ba12f925e" -uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" -version = "1.0.10" - -[[Pkg]] -deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] -uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" - -[[PooledArrays]] -deps = ["DataAPI"] -git-tree-sha1 = "b1333d4eced1826e15adbdf01a4ecaccca9d353c" -uuid = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" -version = "0.5.3" - -[[Printf]] -deps = ["Unicode"] -uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" - -[[ProgressMeter]] -deps = ["Distributed", "Printf"] -git-tree-sha1 = "2de4cddc0ceeddafb6b143b5b6cd9c659b64507c" -uuid = "92933f4c-e287-5a05-a399-4b506db050ca" -version = "1.3.2" - -[[QuadGK]] -deps = ["DataStructures", "LinearAlgebra"] -git-tree-sha1 = "12fbe86da16df6679be7521dfb39fbc861e1dc7b" -uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" -version = "2.4.1" - -[[REPL]] -deps = ["InteractiveUtils", "Markdown", "Sockets"] -uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" - -[[Random]] -deps = ["Serialization"] -uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" - -[[RecipesBase]] -git-tree-sha1 = "58de8f7e33b7fda6ee39eff65169cd1e19d0c107" -uuid = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" -version = "1.0.2" - -[[Reexport]] -deps = ["Pkg"] -git-tree-sha1 = "7b1d07f411bc8ddb7977ec7f377b97b158514fe0" -uuid = "189a3867-3050-52da-a836-e630ba90ab69" -version = "0.2.0" - -[[Requires]] -deps = ["UUIDs"] -git-tree-sha1 = "8c08d0c7812169e438a8478dae2a529377ad13f7" -uuid = "ae029012-a4dd-5104-9daa-d747884805df" -version = "1.0.2" - -[[Rmath]] -deps = ["Random", "Rmath_jll"] -git-tree-sha1 = "86c5647b565873641538d8f812c04e4c9dbeb370" -uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa" -version = "0.6.1" - -[[Rmath_jll]] -deps = ["Libdl", "Pkg"] -git-tree-sha1 = "d76185aa1f421306dec73c057aa384bad74188f0" -uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f" -version = "0.2.2+1" - -[[SHA]] -uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" - -[[Serialization]] -uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" - -[[SharedArrays]] -deps = ["Distributed", "Mmap", "Random", "Serialization"] -uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" - -[[Sockets]] -uuid = "6462fe0b-24de-5631-8697-dd941f90decc" - -[[SortingAlgorithms]] -deps = ["DataStructures", "Random", "Test"] -git-tree-sha1 = "03f5898c9959f8115e30bc7226ada7d0df554ddd" -uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" -version = "0.3.1" - -[[SparseArrays]] -deps = ["LinearAlgebra", "Random"] -uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" - -[[SpecialFunctions]] -deps = ["OpenSpecFun_jll"] -git-tree-sha1 = "d8d8b8a9f4119829410ecd706da4cc8594a1e020" -uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "0.10.3" - -[[StaticArrays]] -deps = ["LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "5a3bcb6233adabde68ebc97be66e95dcb787424c" -uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "0.12.1" - -[[Statistics]] -deps = ["LinearAlgebra", "SparseArrays"] -uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" - -[[StatsBase]] -deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"] -git-tree-sha1 = "d72a47c47c522e283db774fc8c459dd5ed773710" -uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.33.1" - -[[StatsFuns]] -deps = ["Rmath", "SpecialFunctions"] -git-tree-sha1 = "04a5a8e6ab87966b43f247920eab053fd5fdc925" -uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" -version = "0.9.5" - -[[Stheno]] -deps = ["BlockArrays", "Distances", "Distributions", "FillArrays", "LinearAlgebra", "MacroTools", "Random", "RecipesBase", "Requires", "Statistics", "Zygote", "ZygoteRules"] -git-tree-sha1 = "7b672346e683704de182b663e42e5cd044529cfe" -uuid = "8188c328-b5d6-583d-959b-9690869a5511" -version = "0.6.6" - -[[SuiteSparse]] -deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] -uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" - -[[TableTraits]] -deps = ["IteratorInterfaceExtensions"] -git-tree-sha1 = "b1ad568ba658d8cbb3b892ed5380a6f3e781a81e" -uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" -version = "1.0.0" - -[[Tables]] -deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "TableTraits", "Test"] -git-tree-sha1 = "b7f762e9820b7fab47544c36f26f54ac59cf8abf" -uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" -version = "1.0.5" - -[[TemporalGPs]] -deps = ["BlockArrays", "BlockDiagonals", "Distributions", "FillArrays", "Kronecker", "LinearAlgebra", "Random", "StaticArrays", "Stheno", "Zygote", "ZygoteRules"] -path = ".." -uuid = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f" -version = "0.3.3" - -[[Test]] -deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] -uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[[UUIDs]] -deps = ["Random", "SHA"] -uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" - -[[UnPack]] -git-tree-sha1 = "387c1f73762231e86e0c9c5443ce3b4a0a9a0c2b" -uuid = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" -version = "1.0.2" - -[[Unicode]] -uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" - -[[Zygote]] -deps = ["AbstractFFTs", "ArrayLayouts", "ChainRules", "FillArrays", "ForwardDiff", "Future", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "Random", "Requires", "Statistics", "ZygoteRules"] -git-tree-sha1 = "2e2c82549fb0414df10469082fd001e2ede8547c" -uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.4.22" - -[[ZygoteRules]] -deps = ["MacroTools"] -git-tree-sha1 = "b3b4882cc9accf6731a08cc39543fbc6b669dca8" -uuid = "700de1a5-db45-46bc-99cf-38207098b444" -version = "0.2.0" diff --git a/bench/Project.toml b/bench/Project.toml index e8939154..cc8e279b 100644 --- a/bench/Project.toml +++ b/bench/Project.toml @@ -12,4 +12,3 @@ ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Stheno = "8188c328-b5d6-583d-959b-9690869a5511" TemporalGPs = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/bench/lgssm.jl b/bench/lgssm.jl index cd84c750..9fd9ceb2 100644 --- a/bench/lgssm.jl +++ b/bench/lgssm.jl @@ -3,7 +3,7 @@ Pkg.activate(".") Pkg.instantiate() using BenchmarkTools, BlockDiagonals, FillArrays, LinearAlgebra, Random, Stheno, - TemporalGPs, Zygote + TemporalGPs, Mooncake using DataFrames, DrWatson, PGFPlotsX @@ -131,7 +131,7 @@ let # Benchmark logpdf evaluation and gradient evaluation. logpdf_results = @benchmark logpdf($ft, $y) - logpdf_gradient_results = @benchmark Zygote.gradient(logpdf, $ft, $y) + logpdf_gradient_results = @benchmark Mooncake.gradient(logpdf, $ft, $y) # Save results to disk. wsave( diff --git a/bench/predict.jl b/bench/predict.jl index 40312edd..cd075815 100644 --- a/bench/predict.jl +++ b/bench/predict.jl @@ -435,7 +435,7 @@ end # using BenchmarkTools, FillArrays, Kronecker, LinearAlgebra, Random, Stheno, - TemporalGPs, Zygote + TemporalGPs, Mooncake using TemporalGPs: predict @@ -468,11 +468,11 @@ A_dense = collect(A); # using ProfileView # @profview [predict(mf, Pf, A, a, Q) for _ in 1:10] -@benchmark Zygote.pullback(predict, $mf, $Pf, $A, $a, $Q) -@benchmark Zygote.pullback(predict, $mf, $Pf, $A_dense, $a, $Q) +@benchmark Mooncake.pullback(predict, $mf, $Pf, $A, $a, $Q) +@benchmark Mooncake.pullback(predict, $mf, $Pf, $A_dense, $a, $Q) -_, back = Zygote.pullback(predict, mf, Pf, A, a, Q); -_, back_dense = Zygote.pullback(predict, mf, Pf, A_dense, a, Q); +_, back = Mooncake.pullback(predict, mf, Pf, A, a, Q); +_, back_dense = Mooncake.pullback(predict, mf, Pf, A_dense, a, Q); mp = copy(mf); Pp = collect(Pf); @@ -509,7 +509,7 @@ T = Float64; # using BenchmarkTools, BlockDiagonals, FillArrays, Kronecker, LinearAlgebra, Random, Stheno, - TemporalGPs, Zygote + TemporalGPs, Mooncake using TemporalGPs: predict @@ -550,11 +550,11 @@ Q_dense = collect(Q); # using ProfileView # @profview [predict(mf, Pf, A, a, Q) for _ in 1:10] -@benchmark Zygote.pullback(predict, $mf, $Pf, $A, $a, $Q) -@benchmark Zygote.pullback(predict, $mf, $Pf, $A_dense, $a, $Q_dense) +@benchmark Mooncake.pullback(predict, $mf, $Pf, $A, $a, $Q) +@benchmark Mooncake.pullback(predict, $mf, $Pf, $A_dense, $a, $Q_dense) -_, back = Zygote.pullback(predict, mf, Pf, A, a, Q); -_, back_dense = Zygote.pullback(predict, mf, Pf, A_dense, a, Q_dense); +_, back = Mooncake.pullback(predict, mf, Pf, A, a, Q); +_, back_dense = Mooncake.pullback(predict, mf, Pf, A_dense, a, Q_dense); mp = copy(mf); Pp = collect(Pf); diff --git a/bench/single_output_gps.jl b/bench/single_output_gps.jl index ede0c401..22130b1d 100644 --- a/bench/single_output_gps.jl +++ b/bench/single_output_gps.jl @@ -14,7 +14,7 @@ Pkg.instantiate(); using Revise using DrWatson, Stheno, BenchmarkTools, PGFPlotsX, ProgressMeter, TemporalGPs, Random, - DataFrames, Zygote + DataFrames, Mooncake using DrWatson: @dict, @tagsave @@ -146,14 +146,14 @@ let # Generate results including construction of GP. results = @benchmark(build_and_logpdf($(impl.val), $k, $σ², $l, $x, $σ²_n, $y)) - grad_results = @benchmark(Zygote.gradient( + grad_results = @benchmark(Mooncake.gradient( $build_and_logpdf, $(impl.val), $k, $σ², $l, $x, $σ²_n, $y, )) # Generate results excluding construction of GP. fx = build(impl.val, k, σ², l, x, σ²_n) no_build_results = @benchmark(logpdf($fx, $y)) - no_build_grad_results = @benchmark(Zygote.gradient(logpdf, $fx, $y)) + no_build_grad_results = @benchmark(Mooncake.gradient(logpdf, $fx, $y)) # Save results in predictable location. @tagsave( diff --git a/examples/Project.toml b/examples/Project.toml index 1d673c0b..79bd1bd5 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -2,9 +2,9 @@ AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Optim = "429524aa-4258-5aef-a3af-852621145aeb" ParameterHandling = "2412ca09-6db7-441c-8e3a-88d5709968c5" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" -Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b" TemporalGPs = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f" diff --git a/examples/approx_space_time_learning.jl b/examples/approx_space_time_learning.jl index 2748425c..b05c88b9 100644 --- a/examples/approx_space_time_learning.jl +++ b/examples/approx_space_time_learning.jl @@ -1,6 +1,6 @@ # This is an extended version of approx_space_time_inference.jl. It combines it with -# Optim + ParameterHandling + Zygote to learn the kernel parameters. -# If you understand how to use Optim + ParameterHandling + Zygote for an AbstractGP, +# Optim + ParameterHandling + Mooncake to learn the kernel parameters. +# If you understand how to use Optim + ParameterHandling + Mooncake for an AbstractGP, # e.g. that shown on the README for this package, and how approx_space_time_inference.jl # works, then you should understand this file. @@ -13,7 +13,7 @@ using TemporalGPs: Separable, approx_posterior_marginals, RegularInTime # Load standard packages from the Julia ecosystem using Optim # Standard optimisation algorithms. using ParameterHandling # Helper functionality for dealing with model parameters. -using Tapir # Algorithmic Differentiation +using Mooncake # Algorithmic Differentiation using ParameterHandling: flatten @@ -54,73 +54,72 @@ y = sin.(first.(xs)) .+ cos.(last.(xs)) + sqrt.(params.var_noise) .* randn(lengt # Spatial pseudo-point inputs. z_r = collect(range(-3.0, 3.0; length=5)); -# Specify an objective function for Optim to minimise in terms of x and y. -# We choose the usual negative log marginal likelihood (NLML). -function make_objective(unpack, x, y, z_r) - function objective(flat_params) - params = unpack(flat_params) - f = build_gp(params) - return elbo(f(x, params.var_noise), y, z_r) - end - return objective +# # Specify an objective function for Optim to minimise in terms of x and y. +# # We choose the usual negative log marginal likelihood (NLML). +# function make_objective(unpack, x, y, z_r) +# function objective(flat_params) +# params = unpack(flat_params) +# f = build_gp(params) +# return elbo(f(x, params.var_noise), y, z_r) +# end +# return objective +# end +# objective = make_objective(unpack, x, y, z_r) + +function objective(flat_params) + params = unpack(flat_params) + f = build_gp(params) + return -elbo(f(x, params.var_noise), y, z_r) end -objective = make_objective(unpack, x, y, z_r) -using Tapir: CoDual, primal +# using Random +# # y = y +# # z_r = z_r +# # fx = build_gp(unpack(flat_initial_params))(x, params.var_noise) +# # fx_dtc = TemporalGPs.dtcify(z_r, fx) +# # lgssm = TemporalGPs.build_lgssm(fx_dtc) +# # Σs = lgssm.emissions.fan_out.Q +# # marg_diags = TemporalGPs.marginals_diag(lgssm) + +# # k = fx_dtc.f.f.kernel +# # Cf_diags = TemporalGPs.kernel_diagonals(k, fx_dtc.x) + +# # # Transform a vector into a vector-of-vectors. +# # y_vecs = TemporalGPs.restructure(y, lgssm.emissions) + +# # tmp = TemporalGPs.zygote_friendly_map( +# # ((Σ, Cf_diag, marg_diag, yn), ) -> begin +# # Σ_, _ = TemporalGPs.fill_in_missings(Σ, yn) +# # return sum(TemporalGPs.diag(Σ_ \ (Cf_diag - marg_diag.P))) - +# # count(ismissing, yn) + size(Σ_, 1) +# # end, +# # zip(Σs, Cf_diags, marg_diags, y_vecs), +# # ) + +# # logpdf(lgssm, y_vecs) # this is the failing thing + +# # for _ in 1:10 +# # Tapir.TestUtils.test_rule( +# # Xoshiro(123456), objective, flat_initial_params; +# # perf_flag=:none, +# # interp=Tapir.TapirInterpreter(), +# # interface_only=false, +# # is_primitive=false, +# # safety_on=false, +# # ) +# # end -Tapir.@is_primitive Tapir.MinimalCtx Tuple{typeof(TemporalGPs.time_exp), AbstractMatrix{<:Real}, Real} -function Tapir.rrule!!(::CoDual{typeof(TemporalGPs.time_exp)}, A::CoDual, t::CoDual{Float64}) - B_dB = Tapir.zero_fcodual(TemporalGPs.time_exp(primal(A), primal(t))) - B = primal(B_dB) - dB = tangent(B_dB) - time_exp_pb(::NoRData) = NoRData(), NoRData(), sum(dB .* (primal(A) * B)) - return B_dB, time_exp_pb +# Optimise using Optim. +function objective_grad(rule, flat_params) + return Mooncake.value_and_gradient!!(rule, objective, flat_params)[2][2] end +@info "running objective" +@show objective(flat_initial_params) - -using Random -# y = y -# z_r = z_r -# fx = build_gp(unpack(flat_initial_params))(x, params.var_noise) -# fx_dtc = TemporalGPs.dtcify(z_r, fx) -# lgssm = TemporalGPs.build_lgssm(fx_dtc) -# Σs = lgssm.emissions.fan_out.Q -# marg_diags = TemporalGPs.marginals_diag(lgssm) - -# k = fx_dtc.f.f.kernel -# Cf_diags = TemporalGPs.kernel_diagonals(k, fx_dtc.x) - -# # Transform a vector into a vector-of-vectors. -# y_vecs = TemporalGPs.restructure(y, lgssm.emissions) - -# tmp = TemporalGPs.zygote_friendly_map( -# ((Σ, Cf_diag, marg_diag, yn), ) -> begin -# Σ_, _ = TemporalGPs.fill_in_missings(Σ, yn) -# return sum(TemporalGPs.diag(Σ_ \ (Cf_diag - marg_diag.P))) - -# count(ismissing, yn) + size(Σ_, 1) -# end, -# zip(Σs, Cf_diags, marg_diags, y_vecs), -# ) - -# logpdf(lgssm, y_vecs) # this is the failing thing - -for _ in 1:10 - Tapir.TestUtils.test_rule( - Xoshiro(123456), objective, flat_initial_params; - perf_flag=:none, - interp=Tapir.TapirInterpreter(), - interface_only=false, - is_primitive=false, - safety_on=false, - ) -end - -# Optimise using Optim. -rule = Tapir.build_rrule(objective, flat_initial_params); training_results = Optim.optimize( objective, - θ -> Tapir.value_and_gradient!!(rule, objective, θ)[2][2], + Base.Fix1(objective_grad, Mooncake.build_rrule(objective, flat_initial_params)), flat_initial_params + randn(4), # Add some noise to make learning non-trivial BFGS( alphaguess = Optim.LineSearches.InitialStatic(scaled=true), diff --git a/examples/exact_space_time_learning.jl b/examples/exact_space_time_learning.jl index c2ac99ca..a98e6a00 100644 --- a/examples/exact_space_time_learning.jl +++ b/examples/exact_space_time_learning.jl @@ -1,6 +1,6 @@ # This is an extended version of exact_space_time_inference.jl. It combines it with -# Optim + ParameterHandling + Zygote to learn the kernel parameters. -# If you understand how to use Optim + ParameterHandling + Zygote for an AbstractGP, +# Optim + ParameterHandling + Mooncake to learn the kernel parameters. +# If you understand how to use Optim + ParameterHandling + Mooncake for an AbstractGP, # e.g. that shown on the README for this package, and how exact_space_time_inference.jl # works, then you should understand this file. @@ -13,7 +13,7 @@ using TemporalGPs: Separable, RectilinearGrid # Load standard packages from the Julia ecosystem using Optim # Standard optimisation algorithms. using ParameterHandling # Helper functionality for dealing with model parameters. -using Tapir # Algorithmic Differentiation +using Mooncake # Algorithmic Differentiation # Declare model parameters using `ParameterHandling.jl` types. flat_initial_params, unflatten = ParameterHandling.flatten(( @@ -53,23 +53,14 @@ function objective(flat_params) return -logpdf(f(x, params.var_noise), y) end -using Tapir: CoDual, primal - -Tapir.@is_primitive Tapir.MinimalCtx Tuple{typeof(TemporalGPs.time_exp), AbstractMatrix{<:Real}, Real} -function Tapir.rrule!!(::CoDual{typeof(TemporalGPs.time_exp)}, A::CoDual, t::CoDual{Float64}) - B_dB = Tapir.zero_fcodual(TemporalGPs.time_exp(primal(A), primal(t))) - B = primal(B_dB) - dB = tangent(B_dB) - time_exp_pb(::NoRData) = NoRData(), NoRData(), sum(dB .* (primal(A) * B)) - return B_dB, time_exp_pb +function objective_grad(rule, flat_params) + return Mooncake.value_and_gradient!!(rule, objective, flat_params)[2][2] end -rule = Tapir.build_rrule(objective, flat_initial_params); - # Optimise using Optim. training_results = Optim.optimize( objective, - θ -> Tapir.value_and_gradient!!(rule, objective, θ)[2][2], + Base.Fix1(objective_grad, Mooncake.build_rrule(objective, flat_initial_params)), flat_initial_params + randn(4), # Add some noise to make learning non-trivial BFGS( alphaguess = Optim.LineSearches.InitialStatic(scaled=true), diff --git a/examples/exact_time_inference.jl b/examples/exact_time_inference.jl index ad6a0d68..0b7d9e6d 100644 --- a/examples/exact_time_inference.jl +++ b/examples/exact_time_inference.jl @@ -4,7 +4,7 @@ using TemporalGPs # Utilising TemporalGPs.jl to work with AbstractGPs requires minimal modification to the # GP objects from AbstractGPs that you are used to. The primary differences are # 1. RegularSpacing is a useful type. It's basically a `range` that's hacked together to -# work nicely with Zygote.jl. At some point, it will hopefully disappear. +# work nicely with Mooncake.jl. At some point, it will hopefully disappear. # 2. Call `to_sde` on your AbstractGP object to say "use TemporalGPs.jl to do inference". # This is an example of a very, very noise regression problem. diff --git a/examples/exact_time_learning.jl b/examples/exact_time_learning.jl index 46a24203..d8808157 100644 --- a/examples/exact_time_learning.jl +++ b/examples/exact_time_learning.jl @@ -1,5 +1,5 @@ # This is an extended version of exact_time_inference.jl. It combines it with -# Optim + ParameterHandling + Zygote to learn the kernel parameters. +# Optim + ParameterHandling + Mooncake to learn the kernel parameters. # Each of these other packages know nothing about TemporalGPs, they're just general-purpose # packages which play nicely with TemporalGPs (and AbstractGPs). @@ -12,7 +12,7 @@ using TemporalGPs: RegularSpacing # Load standard packages from the Julia ecosystem using Optim # Standard optimisation algorithms. using ParameterHandling # Helper functionality for dealing with model parameters. -using Tapir # Algorithmic Differentiation +using Mooncake # Algorithmic Differentiation # Declare model parameters using `ParameterHandling.jl` types. # var_kernel is the variance of the kernel, λ the inverse length scale, and var_noise the @@ -48,12 +48,14 @@ function objective(flat_params) return -logpdf(f(x, params.var_noise), y) end -rule = Tapir.build_rrule(objective, flat_initial_params); +function objective_grad(rule, flat_params) + return Mooncake.value_and_gradient!!(rule, objective, flat_params)[2][2] +end -# Optimise using Optim. Zygote takes a little while to compile. +# Optimise using Optim. Mooncake takes a little while to compile. training_results = Optim.optimize( objective, - θ -> Tapir.value_and_gradient!!(rule, objective, θ)[2][2], + Base.Fix1(objective_grad, Mooncake.build_rrule(objective, flat_initial_params)), flat_initial_params .+ randn.(), # Perturb the parameters to make learning non-trivial BFGS( alphaguess = Optim.LineSearches.InitialStatic(scaled=true), diff --git a/ext/TemporalGPsMooncakeExt.jl b/ext/TemporalGPsMooncakeExt.jl new file mode 100644 index 00000000..4be3a2e9 --- /dev/null +++ b/ext/TemporalGPsMooncakeExt.jl @@ -0,0 +1,15 @@ +module TemporalGPsMooncakeExt + +using Mooncake, TemporalGPs +import Mooncake: rrule!!, CoDual, primal, @is_primitive, zero_fcodual, MinimalCtx + +@is_primitive MinimalCtx Tuple{typeof(TemporalGPs.time_exp), AbstractMatrix{<:Real}, Real} +function rrule!!(::CoDual{typeof(TemporalGPs.time_exp)}, A::CoDual, t::CoDual{Float64}) + B_dB = zero_fcodual(TemporalGPs.time_exp(primal(A), primal(t))) + B = primal(B_dB) + dB = tangent(B_dB) + time_exp_pb(::NoRData) = NoRData(), NoRData(), sum(dB .* (primal(A) * B)) + return B_dB, time_exp_pb +end + +end diff --git a/src/TemporalGPs.jl b/src/TemporalGPs.jl index 7e9cf09c..c447f6fa 100644 --- a/src/TemporalGPs.jl +++ b/src/TemporalGPs.jl @@ -3,15 +3,12 @@ module TemporalGPs using AbstractGPs using Bessels: besseli using BlockDiagonals - using ChainRulesCore - import ChainRulesCore: rrule using FillArrays using LinearAlgebra using KernelFunctions using Random using StaticArrays using StructArrays - using Zygote using FillArrays: AbstractFill @@ -36,12 +33,10 @@ module TemporalGPs ApproxPeriodicKernel # Various bits-and-bobs. Often commiting some type piracy. - include(joinpath("util", "harmonise.jl")) include(joinpath("util", "linear_algebra.jl")) include(joinpath("util", "scan.jl")) include(joinpath("util", "zygote_friendly_map.jl")) - include(joinpath("util", "chainrules.jl")) include(joinpath("util", "gaussian.jl")) include(joinpath("util", "mul.jl")) include(joinpath("util", "storage_types.jl")) diff --git a/src/gp/lti_sde.jl b/src/gp/lti_sde.jl index 03d7f7ed..2b521b18 100644 --- a/src/gp/lti_sde.jl +++ b/src/gp/lti_sde.jl @@ -88,7 +88,7 @@ get_kernel(f::GP) = f.kernel function build_emissions( (Hs, hs)::Tuple{AbstractVector, AbstractVector}, Σs::AbstractVector, ) - Hst = _map(adjoint, Hs) + Hst = map(adjoint, Hs) return StructArray{get_type(Hst, hs, Σs)}((Hst, hs, Σs)) end @@ -108,10 +108,6 @@ function get_type(Hs_prime, hs::AbstractVector{<:AbstractVector}, Σs) return T end -@inline function Zygote.wrap_chainrules_output(x::NamedTuple) - return map(Zygote.wrap_chainrules_output, x) -end - # Constructor for combining kernel and mean functions function lgssm_components( ::ZeroMean, k::Kernel, t::AbstractVector, storage_type::StorageType @@ -135,14 +131,17 @@ function add_proj_mean(hs::AbstractVector, m) return map((h, m) -> h + vcat(m, Zeros(length(h) - 1)), hs, m) end +# Really just a hook for AD. +time_exp(A, t) = exp(A * t) + # Generic constructors for base kernels. function broadcast_components((F, q, H)::Tuple, x0::Gaussian, t::AbstractVector{<:Real}, ::StorageType{T}) where {T} P = Symmetric(x0.P) t = vcat([first(t) - 1], t) - As = _map(Δt -> time_exp(F, T(Δt)), diff(t)) + As = map(Δt -> time_exp(F, T(Δt)), diff(t)) as = Fill(Zeros{T}(size(first(As), 1)), length(As)) - Qs = _map(A -> P - A * P * A', As) + Qs = map(A -> P - A * P * A', As) Hs = Fill(H, length(As)) hs = Fill(zero(T), length(As)) As, as, Qs, Hs, hs @@ -152,7 +151,7 @@ function broadcast_components((F, q, H)::Tuple, x0::Gaussian, t::Union{StepRange P = Symmetric(x0.P) A = time_exp(F, T(step(t))) As = Fill(A, length(t)) - as = @ignore_derivatives(Fill(Zeros{T}(size(F, 1)), length(t))) + as = Fill(Zeros{T}(size(F, 1)), length(t)) Q = Symmetric(P) - A * Symmetric(P) * A' Qs = Fill(Q, length(t)) Hs = Fill(H, length(t)) @@ -342,11 +341,11 @@ function lgssm_components(k::ScaledKernel, ts::AbstractVector, storage_type::Sto end function _scale_emission_projections((Hs, hs)::Tuple{AbstractVector, AbstractVector}, σ::Real) - return _map(H->σ * H, Hs), _map(h->σ * h, hs) + return map(H->σ * H, Hs), map(h->σ * h, hs) end function _scale_emission_projections((Cs, cs, Hs, hs), σ) - return (Cs, cs, _map(H -> σ * H, Hs), _map(h -> σ * h, hs)) + return (Cs, cs, map(H -> σ * H, Hs), map(h -> σ * h, hs)) end # Stretched @@ -412,9 +411,9 @@ function lgssm_components(k::KernelSum, ts::AbstractVector, storage_type::Storag emission_proj_kernels = getindex.(lgssms, 4) x0_kernels = getindex.(lgssms, 5) - As = _map(block_diagonal, As_kernels...) - as = _map(vcat, as_kernels...) - Qs = _map(block_diagonal, Qs_kernels...) + As = map(block_diagonal, As_kernels...) + as = map(vcat, as_kernels...) + Qs = map(block_diagonal, Qs_kernels...) emission_projections = _sum_emission_projections(emission_proj_kernels...) x0 = Gaussian(mapreduce(x -> getproperty(x, :m), vcat, x0_kernels), block_diagonal(getproperty.(x0_kernels, :P)...)) return As, as, Qs, emission_projections, x0 @@ -431,52 +430,18 @@ function _sum_emission_projections( cs = getindex.(Cs_cs_Hs_hs, 2) Hs = getindex.(Cs_cs_Hs_hs, 3) hs = getindex.(Cs_cs_Hs_hs, 4) - C = _map(vcat, Cs...) + C = map(vcat, Cs...) c = sum(cs) - H = _map(block_diagonal, Hs...) - h = _map(vcat, hs...) + H = map(block_diagonal, Hs...) + h = map(vcat, hs...) return C, c, H, h end Base.vcat(x::Zeros{T, 1}, y::Zeros{T, 1}) where {T} = Zeros{T}(length(x) + length(y)) -function block_diagonal(As::AbstractMatrix{T}...) where {T} - nblocks = length(As) - sizes = size.(As) - Xs = [i == j ? As[i] : Zeros{T}(sizes[j][1], sizes[i][2]) for i in 1:nblocks, j in 1:nblocks] - return hvcat(ntuple(_ -> nblocks, nblocks), Xs...) -end - -function ChainRulesCore.rrule(::typeof(block_diagonal), As::AbstractMatrix...) - szs = size.(As) - row_szs = (0, cumsum(first.(szs))...) - col_szs = (0, cumsum(last.(szs))...) - block_diagonal_rrule(Δ::AbstractThunk) = block_diagonal_rrule(unthunk(Δ)) - function block_diagonal_rrule(Δ) - ΔAs = ntuple(length(As)) do i - Δ[(row_szs[i]+1):row_szs[i+1], (col_szs[i]+1):col_szs[i+1]] - end - return NoTangent(), ΔAs... - end - return block_diagonal(As...), block_diagonal_rrule -end +block_diagonal(As::AbstractMatrix{T}...) where {T} = collect(BlockDiagonal(collect(As))) function block_diagonal(As::SMatrix...) - nblocks = length(As) - sizes = size.(As) - Xs = [i == j ? As[i] : zeros(SMatrix{sizes[j][1], sizes[i][2]}) for i in 1:nblocks, j in 1:nblocks] - return hcat(Base.splat(vcat).(eachrow(Xs))...) -end - -function ChainRulesCore.rrule(::typeof(block_diagonal), As::SMatrix...) - szs = size.(As) - row_szs = (0, cumsum(first.(szs))...) - col_szs = (0, cumsum(last.(szs))...) - function block_diagonal_rrule(Δ) - ΔAs = ntuple(length(As)) do i - Δ[SVector{szs[i][1]}((row_szs[i]+1):row_szs[i+1]), SVector{szs[i][2]}((col_szs[i]+1):col_szs[i+1])] - end - return NoTangent(), ΔAs... - end - return block_diagonal(As...), block_diagonal_rrule + M = block_diagonal(map(collect, As)...) + return SMatrix{sum(map(A -> size(A, 1), As)), sum(map(A -> size(A, 2), As))}(M) end diff --git a/src/models/gauss_markov_model.jl b/src/models/gauss_markov_model.jl index 4c2d62cf..95d9b65e 100644 --- a/src/models/gauss_markov_model.jl +++ b/src/models/gauss_markov_model.jl @@ -31,14 +31,6 @@ struct GaussMarkovModel{ x0::Tx0 end -# Helps Zygote out with some type-stability issues. Why this helps is unclear. -function ChainRulesCore.rrule(::Type{<:GaussMarkovModel}, ordering, As, as, Qs, x0) - function GaussMarkovModel_pullback(Δ) - return NoTangent(), NoTangent(), Δ.As, Δ.as, Δ.Qs, Δ.x0 - end - return GaussMarkovModel(ordering, As, as, Qs, x0), GaussMarkovModel_pullback -end - ordering(model::GaussMarkovModel) = model.ordering Base.eltype(model::GaussMarkovModel) = eltype(first(model.As)) @@ -66,27 +58,3 @@ function is_of_storage_type(model::GaussMarkovModel, s::StorageType) end x0(model::GaussMarkovModel) = model.x0 - -function get_adjoint_storage(x::GaussMarkovModel, n::Int, Δx::Tangent{T,<:NamedTuple{(:A, :a, :Q)}}) where {T} - return ( - ordering = NoTangent(), - As = get_adjoint_storage(x.As, n, Δx.A), - as = get_adjoint_storage(x.as, n, Δx.a), - Qs = get_adjoint_storage(x.Qs, n, Δx.Q), - x0 = NoTangent(), - ) -end - -function _accum_at( - Δxs::NamedTuple{(:ordering, :As, :as, :Qs, :x0)}, - n::Int, - Δx::Tangent{T, <:NamedTuple{(:A, :a, :Q)}}, -) where {T} - return ( - ordering = NoTangent(), - As = _accum_at(Δxs.As, n, Δx.A), - as = _accum_at(Δxs.as, n, Δx.a), - Qs = _accum_at(Δxs.Qs, n, Δx.Q), - x0 = NoTangent(), - ) -end diff --git a/src/models/lgssm.jl b/src/models/lgssm.jl index 9c211b77..86cabd45 100644 --- a/src/models/lgssm.jl +++ b/src/models/lgssm.jl @@ -16,7 +16,6 @@ end @inline emissions(model::LGSSM) = model.emissions @inline ordering(model::LGSSM) = ordering(transitions(model)) -ChainRulesCore.@non_differentiable ordering(model) function Base.:(==)(x::LGSSM, y::LGSSM) return (transitions(x) == transitions(y)) && (emissions(x) == emissions(y)) @@ -29,8 +28,6 @@ Base.eachindex(model::LGSSM) = eachindex(transitions(model)) storage_type(model::LGSSM) = storage_type(transitions(model)) -ChainRulesCore.@non_differentiable storage_type(x) - function is_of_storage_type(model::LGSSM, s::StorageType) return is_of_storage_type((transitions(model), emissions(model)), s) end @@ -210,8 +207,6 @@ function _check_inputs(prior, y) end end -ChainRulesCore.@non_differentiable _check_inputs(::Any, ::Any) - function _a_bit_of_posterior(prior, y) return scan_emit(step_posterior, zip(prior, y), x0(prior), eachindex(prior)) end @@ -252,30 +247,6 @@ ident_eps(ε::Real) = UniformScaling(ε) ident_eps(x::ColVecs, ε::Real) = UniformScaling(convert(eltype(x.X), ε)) -ChainRulesCore.@non_differentiable ident_eps(args...) - _collect(U::Adjoint{<:Any, <:Matrix}) = collect(U) _collect(U::SMatrix) = U _collect(U::BlockDiagonal) = U - -# AD stuff. No need to understand this unless you're really plumbing the depths... - -function get_adjoint_storage( - x::LGSSM, n::Int, Δx::Tangent{T,<:NamedTuple{(:ordering,:transition,:emission)}}, -) where {T} - return Tangent{typeof(x)}( - transitions = get_adjoint_storage(x.transitions, n, Δx.transition), - emissions = get_adjoint_storage(x.emissions, n, Δx.emission) - ) -end - -function _accum_at( - Δxs::Tangent{X}, - n::Int, - Δx::Tangent{T,<:NamedTuple{(:ordering,:transition,:emission)}}, -) where {X<:LGSSM, T} - return Tangent{X}( - transitions = _accum_at(Δxs.transitions, n, Δx.transition), - emissions = _accum_at(Δxs.emissions, n, Δx.emission), - ) -end diff --git a/src/models/linear_gaussian_conditionals.jl b/src/models/linear_gaussian_conditionals.jl index 62f65f69..a6c30fa4 100644 --- a/src/models/linear_gaussian_conditionals.jl +++ b/src/models/linear_gaussian_conditionals.jl @@ -97,13 +97,9 @@ function ε_randn(rng::AbstractRNG, ::SMatrix{Dout, Din, T}) where {Dout, Din, T return randn(rng, SVector{Dout, T}) end -ChainRulesCore.@non_differentiable ε_randn(args...) - scalar_type(::AbstractVector{T}) where {T} = T scalar_type(::T) where {T<:Real} = T -ChainRulesCore.@non_differentiable scalar_type(x) - """ SmallOutputLGC{ TA<:AbstractMatrix, Ta<:AbstractVector, TQ<:AbstractMatrix, @@ -172,16 +168,6 @@ struct LargeOutputLGC{ Q::TQ end -function ChainRulesCore.rrule( - ::Type{<:LargeOutputLGC}, - A::AbstractMatrix, - a::AbstractVector, - Q::AbstractMatrix, -) - LargeOutputLGC_pullback(Δ) = NoTangent(), Δ.A, Δ.a, Δ.Q - return LargeOutputLGC(A, a, Q), LargeOutputLGC_pullback -end - dim_out(f::LargeOutputLGC) = size(f.A, 1) dim_in(f::LargeOutputLGC) = size(f.A, 2) diff --git a/src/models/missings.jl b/src/models/missings.jl index 18ffb584..fbbd1e22 100644 --- a/src/models/missings.jl +++ b/src/models/missings.jl @@ -54,9 +54,6 @@ function _logpdf_volume_compensation(y::AbstractVector{<:Union{Missing, <:Real}} return count(ismissing, y) * log(2π * _large_var_const()) / 2 end - -ChainRulesCore.@non_differentiable _logpdf_volume_compensation(y) - function fill_in_missings(Σs::Vector, y::AbstractVector{Union{Missing, T}}) where {T} return _fill_in_missings(Σs, y) end @@ -91,38 +88,6 @@ end fill_in_missings(Σ::Diagonal, y::AbstractVector{<:Real}) = (Σ, y) -function ChainRulesCore.rrule( - ::typeof(_fill_in_missings), - Σs::Vector, - y::AbstractVector{Union{T, Missing}}, -) where {T} - function _fill_in_missings_rrule(Δ::Tangent) - ΔΣs, Δy_filled = Δ - - # The cotangent of a `Missing` doesn't make sense, so should be a `NoTangent`. - Δy = if Δy_filled isa AbstractZero - ZeroTangent() - else - Δy = Vector{Union{eltype(Δy_filled), ZeroTangent}}(undef, length(y)) - map!( - n -> y[n] === missing ? ZeroTangent() : Δy_filled[n], - Δy, eachindex(y), - ) - Δy - end - - # Fill in missing locations with zeros. Opting for type-stability to keep things - # simple. - ΔΣs = map( - n -> y[n] === missing ? zero(Σs[n]) : ΔΣs[n], - eachindex(y), - ) - - return NoTangent(), ΔΣs, Δy - end - return fill_in_missings(Σs, y), _fill_in_missings_rrule -end - get_zero(D::Int, ::Type{Vector{T}}) where {T} = zeros(T, D) get_zero(::Int, ::Type{T}) where {T<:SVector} = zeros(T) @@ -136,5 +101,3 @@ build_large_var(::T) where {T<:SMatrix} = T(_large_var_const() * I) build_large_var(S::T) where {T<:Diagonal} = T(fill(_large_var_const(), length(diag(S)))) build_large_var(::T) where {T<:Real} = T(_large_var_const()) - -ChainRulesCore.@non_differentiable build_large_var(::Any) diff --git a/src/space_time/pseudo_point.jl b/src/space_time/pseudo_point.jl index c5867df0..9020bf8f 100644 --- a/src/space_time/pseudo_point.jl +++ b/src/space_time/pseudo_point.jl @@ -54,9 +54,6 @@ function AbstractGPs.dtc(fx::FiniteLTISDE, y::AbstractVector, z_r::AbstractVecto return logpdf(dtcify(z_r, fx), y) end -# This stupid rule saves an absurb amount of compute time. -ChainRulesCore.@non_differentiable count(::typeof(ismissing), yn) - """ elbo(fx::FiniteLTISDE, y::AbstractVector{<:Real}, z_r::AbstractVector) @@ -91,8 +88,6 @@ function AbstractGPs.elbo(fx::FiniteLTISDE, y::AbstractVector, z_r::AbstractVect return logpdf(lgssm, y_vecs) - sum(tmp) / 2 end -Zygote.accum(x::NamedTuple{(:diag, )}, y::Diagonal) = Zygote.accum(x, (diag=y.diag, )) - function kernel_diagonals(k::DTCSeparable, x::RectilinearGrid) space_kernel = k.k.l time_kernel = k.k.r @@ -135,19 +130,19 @@ function lgssm_components(k_dtc::DTCSeparable, x::SpaceTimeGrid, storage::Storag # Get some size info. M = length(z_space) N = length(x_space) - ident_M = my_I(eltype(storage), M) + ident_M = Matrix{eltype(storage)}(I, M, M) # G is the time-invariant component of the H-matrices. It is only time-invariant because # we have the same obsevation locations at each point in time. Λu_Cuf = cholesky(Symmetric(K_space_z + 1e-12I)) \ K_space_zx # Construct approximately low-rank model spatio-temporal LGSSM. - As = _map(A -> kron(ident_M, A), As_t) - as = _map(a -> repeat(a, M), as_t) - Qs = _map(Q -> kron(K_space_z, Q), Qs_t) + As = map(A -> kron(ident_M, A), As_t) + as = map(a -> repeat(a, M), as_t) + Qs = map(Q -> kron(K_space_z, Q), Qs_t) Cs = Fill(Λu_Cuf, length(ts)) - cs = _map(h -> Fill(h, N), hs_t) # This should currently be zero. - Hs = _map(H -> kron(ident_M, H), Hs_t) + cs = map(h -> Fill(h, N), hs_t) # This should currently be zero. + Hs = map(H -> kron(ident_M, H), Hs_t) hs = Fill(Zeros(M), length(ts)) x0 = Gaussian(repeat(x0_t.m, M), kron(K_space_z, x0_t.P)) return As, as, Qs, (Cs, cs, Hs, hs), x0 @@ -171,22 +166,19 @@ function lgssm_components(k_dtc::DTCSeparable, x::RegularInTime, storage::Storag # Get some size info. M = length(z_space) N = length(ts) - ident_M = my_I(eltype(storage), M) + ident_M = Matrix{eltype(storage)}(I, M, M) # Construct approximately low-rank model spatio-temporal LGSSM. - As = _map(kron, Fill(ident_M, N), As_t) - as = _map(a -> repeat(a, M), as_t) - Qs = _map(kron, Fill(K_space_z, N), Qs_t) + As = map(kron, Fill(ident_M, N), As_t) + as = map(a -> repeat(a, M), as_t) + Qs = map(kron, Fill(K_space_z, N), Qs_t) x_big = _reduce(vcat, x.vs) C__ = kernelmatrix(space_kernel, z_space, x_big) C = \(K_space_z_chol, C__) - Cs = partition(ChainRulesCore.ignore_derivatives(map(length, x.vs)), C) + Cs = partition(map(length, x.vs), C) cs = fill.(hs_t, length.(x.vs)) # This should currently be zero. - Hs = _map( - ((I, H_t), ) -> kron(I, H_t), - zip(Fill(ident_M, N), Hs_t), - ) + Hs = map(((I, H_t), ) -> kron(I, H_t), zip(Fill(ident_M, N), Hs_t)) hs = Fill(Zeros(M), N) x0 = Gaussian(repeat(x0_t.m, M), kron(K_space_z, x0_t.P)) @@ -207,22 +199,12 @@ function partition(lengths::AbstractVector{<:Integer}, A::Matrix{<:Real}) return map((s, d) -> collect(view(A, :, s:s+d-1)), starts, lengths) end -function ChainRulesCore.rrule( - ::typeof(partition), - lengths::AbstractVector{<:Integer}, - A::Matrix{<:Real}, -) - partition_pullback(::NoTangent) = NoTangent(), NoTangent(), NoTangent() - partition_pullback(Δ::Vector) = NoTangent(), NoTangent(), reduce(hcat, Δ) - return partition(lengths, A), partition_pullback -end - function build_emissions( (Cs, cs, Hs, hs)::Tuple{AbstractVector, AbstractVector, AbstractVector, AbstractVector}, Σs::AbstractVector, ) - Hst = _map(adjoint, Hs) - Cst = _map(adjoint, Cs) + Hst = map(adjoint, Hs) + Cst = map(adjoint, Cs) fan_outs = StructArray{LargeOutputLGC{eltype(Cs), eltype(cs), eltype(Σs)}}((Cst, cs, Σs)) return StructArray{BottleneckLGC{eltype(Hst), eltype(hs), eltype(fan_outs)}}((Hst, hs, fan_outs)) end @@ -374,16 +356,16 @@ end function dtc_post_emissions(k::ScaledKernel, x_new::AbstractVector, storage::StorageType) (Cs, cs, Hs, hs), Σs = dtc_post_emissions(k.kernel, x_new, storage) σ = sqrt(convert(eltype(storage_type), only(k.σ²))) - return (Cs, cs, _map(H->σ * H, Hs), _map(h->σ * h, hs)), _map(Σ->σ^2 * Σ, Σs) + return (Cs, cs, map(H->σ * H, Hs), map(h->σ * h, hs)), map(Σ->σ^2 * Σ, Σs) end function dtc_post_emissions(k::KernelSum, x_new::AbstractVector, storage::StorageType) post_emissions = dtc_post_emissions.(k.kernels, Ref(x_new), Ref(storage)) Cs_cs_Hs_hs = getindex.(post_emissions, 1) Σs = getindex.(post_emissions, 2) - Cs = _map(vcat, getindex.(Cs_cs_Hs_hs, 1)...) + Cs = map(vcat, getindex.(Cs_cs_Hs_hs, 1)...) cs = sum(getindex.(Cs_cs_Hs_hs, 2)) - Hs = _map(block_diagonal, getindex.(Cs_cs_Hs_hs, 3)...) - hs = _map(vcat, getindex.(Cs_cs_Hs_hs, 4)...) + Hs = map(block_diagonal, getindex.(Cs_cs_Hs_hs, 3)...) + hs = map(vcat, getindex.(Cs_cs_Hs_hs, 4)...) return (Cs, cs, Hs, hs), sum(Σs) end diff --git a/src/space_time/rectilinear_grid.jl b/src/space_time/rectilinear_grid.jl index 786bb078..96e9506f 100644 --- a/src/space_time/rectilinear_grid.jl +++ b/src/space_time/rectilinear_grid.jl @@ -90,12 +90,7 @@ end # See docstring elsewhere for context. function noise_var_to_time_form(x::RectilinearGrid, S::Diagonal{<:Real}) - vs = restructure( - diag(S), - ChainRulesCore.ignore_derivatives() do - Fill(length(get_space(x)), length(get_times(x))) - end, - ) + vs = restructure(diag(S), Fill(length(get_space(x)), length(get_times(x)))) return zygote_friendly_map(v -> Diagonal(collect(v)), vs) end diff --git a/src/space_time/regular_in_time.jl b/src/space_time/regular_in_time.jl index 3c823118..452d8318 100644 --- a/src/space_time/regular_in_time.jl +++ b/src/space_time/regular_in_time.jl @@ -78,18 +78,11 @@ function restructure(y::AbstractVector{T}, lengths::AbstractVector{<:Integer}) w end end -function ChainRulesCore.rrule( - ::typeof(restructure), y::Vector, lengths::AbstractVector{<:Integer}, -) - restructure_pullback(Δ::Vector) = NoTangent(), reduce(vcat, Δ), NoTangent() - return restructure(y, lengths), restructure_pullback -end - # Implementation specific to Fills for AD's sake. function restructure(y::Fill{<:Real}, lengths::AbstractVector{<:Integer}) - return map(l -> Fill(y.value, l), ChainRulesCore.ignore_derivatives(lengths)) + return map(l -> Fill(y.value, l), lengths) end function restructure(y::AbstractVector, emissions::StructArray) - return restructure(y, ChainRulesCore.ignore_derivatives(map(dim_out, emissions))) + return restructure(y, map(dim_out, emissions)) end diff --git a/src/space_time/to_gauss_markov.jl b/src/space_time/to_gauss_markov.jl index 5b6afb72..78833215 100644 --- a/src/space_time/to_gauss_markov.jl +++ b/src/space_time/to_gauss_markov.jl @@ -1,7 +1,3 @@ -using ChainRulesCore -my_I(T, N) = Matrix{T}(I, N, N) -ChainRulesCore.@non_differentiable my_I(args...) - function lgssm_components(k::Separable, x::SpaceTimeGrid, storage) # Compute spatial covariance, and temporal GaussMarkovModel. @@ -14,17 +10,17 @@ function lgssm_components(k::Separable, x::SpaceTimeGrid, storage) # Compute components of complete LGSSM. Nr = length(r) - ident = my_I(eltype(storage), Nr) - As = _map(Base.Fix1(kron, ident), As_t) - as = _map(Base.Fix2(repeat, Nr), as_t) - Qs = _map(Base.Fix1(kron, Kr + ident_eps(1e-12)), Qs_t) + ident = Matrix{eltype(storage)}(I, Nr, Nr) + As = map(Base.Fix1(kron, ident), As_t) + as = map(Base.Fix2(repeat, Nr), as_t) + Qs = map(Base.Fix1(kron, Kr + ident_eps(1e-12)), Qs_t) emission_proj = _build_st_proj(emission_proj_t, Nr, ident) x0 = Gaussian(repeat(x0_t.m, Nr), kron(Kr, x0_t.P)) return As, as, Qs, emission_proj, x0 end function _build_st_proj((Hs, hs)::Tuple{AbstractVector, AbstractVector}, Nr::Integer, ident) - return (_map(H -> kron(ident, H), Hs), _map(h -> Fill(h, Nr), hs)) + return (map(H -> kron(ident, H), Hs), map(h -> Fill(h, Nr), hs)) end function build_prediction_obs_vars( diff --git a/src/util/chainrules.jl b/src/util/chainrules.jl deleted file mode 100644 index fe9e67eb..00000000 --- a/src/util/chainrules.jl +++ /dev/null @@ -1,409 +0,0 @@ -# This is all AD-related stuff. If you're looking to understand TemporalGPs, this can be -# safely ignored. - -using Zygote: accum, AContext -import ChainRulesCore: ProjectTo, rrule, _eltype_projectto - -# This context doesn't allow any globals. -struct NoContext <: Zygote.AContext end - -# Stupid implementation to obtain type-stability. -Zygote.cache(::NoContext) = (; cache_fields=nothing) - -# Stupid implementation. -Base.haskey(cx::NoContext, x) = false - -Zygote.accum_param(::NoContext, x, Δ) = Δ - -ChainRulesCore.@non_differentiable eltype(x) - -# Hacks to help the compiler out in very specific situations. -Zygote.accum(a::Array{T}, b::Array{T}) where {T<:Real} = a + b - -Zygote.accum(a::SArray{size, T}, b::SArray{size, T}) where {size, T<:Real} = a + b - -Zygote.accum(a::Tuple, b::Tuple, c::Tuple) = map(Zygote.accum, a, b, c) - -# ---------------------------------------------------------------------------- # -# StaticArrays # -# ---------------------------------------------------------------------------- # - -function rrule(::Type{T}, x::Tuple) where {T<:SArray} - SArray_rrule(Δ) = begin - (NoTangent(), Tangent{typeof(x)}(unthunk(Δ).data...)) - end - return T(x), SArray_rrule -end - -function rrule(::RuleConfig{>:HasReverseMode}, ::Type{SArray{S, T, N, L}}, x::NTuple{L, T}) where {S, T, N, L} - SArray_rrule(::AbstractZero) = NoTangent(), NoTangent() - SArray_rrule(Δ::NamedTuple{(:data,)}) = NoTangent(), Δ.data - SArray_rrule(Δ::StaticArray{S}) = NoTangent(), Δ.data - return SArray{S, T, N, L}(x), SArray_rrule -end - -function rrule( - config::RuleConfig{>:HasReverseMode}, ::Type{X}, x::NTuple{L, Any}, -) where {S, T, N, L, X <: SArray{S, T, N, L}} - new_x, convert_pb = rrule_via_ad(config, StaticArrays.convert_ntuple, T, x) - _, pb = rrule_via_ad(config, SArray{S, T, N, L}, new_x) - SArray_rrule(::AbstractZero) = NoTangent(), NoTangent() - SArray_rrule(Δ::SArray{S}) = SArray_rrule(Tangent{X}(data=Δ.data)) - SArray_rrule(Δ::SizedArray{S}) = SArray_rrule(Tangent{X}(data=Tuple(Δ.data))) - SArray_rrule(Δ::AbstractVector) = SArray_rrule(Tangent{X}(data=Tuple(Δ))) - SArray_rrule(Δ::Matrix) = SArray_rrule(Tangent{X}(data=Δ)) - function SArray_rrule(Δ::Tangent{X,<:NamedTuple{(:data,)}}) where {X} - _, Δnew_x = pb(backing(Δ)) - _, ΔT, Δx = convert_pb(Tuple(Δnew_x)) - return ΔT, Δx - end - return SArray{S, T, N, L}(x), SArray_rrule -end - -function rrule(::typeof(collect), x::X) where {S, T, N, L, X<:SArray{S, T, N, L}} - y = collect(x) - proj = ProjectTo(y) - collect_rrule(Δ) = NoTangent(), proj(Δ) - return y, collect_rrule -end - -function rrule(::typeof(vcat), A::SVector{DA}, B::SVector{DB}) where {DA, DB} - function vcat_rrule(Δ) # SVector - ΔA = Δ[SVector{DA}(1:DA)] - ΔB = Δ[SVector{DB}((DA+1):(DA+DB))] - return NoTangent(), ΔA, ΔB - end - return vcat(A, B), vcat_rrule -end - -@non_differentiable vcat(x::Zeros, y::Zeros) - -# Implementation of the matrix exponential that assumes one doesn't require access to the -# gradient w.r.t. `A`, only `t`. The former is a bit compute-intensive to get at, while the -# latter is very cheap. - -time_exp(A, t) = exp(A * t) -function rrule(::typeof(time_exp), A, t::Real) - B = exp(A * t) - time_exp_rrule(Ω̄) = NoTangent(), NoTangent(), sum(Ω̄ .* (A * B)) - return B, time_exp_rrule -end - - -# Following is taken from https://github.com/JuliaArrays/FillArrays.jl/pull/153 -# Until a solution has been found this code will be needed here. -""" - ProjectTo(::Fill) -> ProjectTo{Fill} - ProjectTo(::Ones) -> ProjectTo{NoTangent} - -Most FillArrays arrays store one number, and so their gradients under automatic -differentiation represent the variation of this one number. - -The exception is those like `Ones` and `Zeros` whose type fixes their value, -which have no graidient. -""" -ProjectTo(x::Fill) = ProjectTo{Fill}(; element = ProjectTo(FillArrays.getindex_value(x)), axes = axes(x)) - -ProjectTo(::AbstractFill{Bool}) = ProjectTo{NoTangent}() # Bool is always regarded as categorical - -ProjectTo(::Zeros) = ProjectTo{NoTangent}() -ProjectTo(::Ones) = ProjectTo{NoTangent}() - -(project::ProjectTo{Fill})(x::Fill) = x -function (project::ProjectTo{Fill})(dx::AbstractArray) - for d in 1:max(ndims(dx), length(project.axes)) - size(dx, d) == length(get(project.axes, d, 1)) || throw(_projection_mismatch(axes_x, size(dx))) - end - Fill(sum(dx), project.axes) -end - -function (project::ProjectTo{Fill})(dx::Tangent{<:Fill}) - # This would need a definition for length(::NoTangent) to be safe: - # for d in 1:max(length(dx.axes), length(project.axes)) - # length(get(dx.axes, d, 1)) == length(get(project.axes, d, 1)) || throw(_projection_mismatch(dx.axes, size(dx))) - # end - Fill(dx.value / prod(length, project.axes), project.axes) -end -function (project::ProjectTo{Fill})(dx::Tangent{Any,<:NamedTuple{(:value, :axes)}}) - Fill(dx.value / prod(length, project.axes), project.axes) -end - -# Yet another thing that should not happen -function Zygote.accum(x::Fill, y::NamedTuple{(:value, :axes)}) - Fill(x.value + y.value, x.axes) -end - -# We have an alternative map to avoid Zygote untouchable specialisation on map. -_map(f, args...) = map(f, args...) - -function rrule(::Type{<:Fill}, x, sz) - Fill_rrule(Δ::Union{Fill,Thunk}) = NoTangent(), FillArrays.getindex_value(unthunk(Δ)), NoTangent() - Fill_rrule(Δ::Tangent{T,<:NamedTuple{(:value, :axes)}}) where {T} = NoTangent(), Δ.value, NoTangent() - Fill_rrule(::AbstractZero) = NoTangent(), NoTangent(), NoTangent() - Fill_rrule(Δ::Tangent{T,<:NTuple}) where {T} = NoTangent(), sum(Δ), NoTangent() - function Fill_rrule(Δ::AbstractArray) - # all(==(first(Δ)), Δ) || error("Δ should be a vector of the same value") - # sum(Δ) - # TODO Fix this rule, or what seems to be a downstream bug. - return NoTangent(), sum(Δ), NoTangent() - end - Fill(x, sz), Fill_rrule -end - -function rrule(::typeof(Base.collect), x::Fill) - y = collect(x) - proj = ProjectTo(x) - function collect_Fill_rrule(Δ) - NoTangent(), proj(Δ) - end - return y, collect_Fill_rrule -end - - -function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(_map), f, x::Fill) - y_el, back = ChainRulesCore.rrule_via_ad(config, f, x.value) - function _map_Fill_rrule(Δ::AbstractArray) - all(==(first(Δ)), Δ) || error("Δ should be a vector of the same value") - Δf, Δx_el = back(first(Δ)) - NoTangent(), Δf, Fill(Δx_el, axes(x)) - end - function _map_Fill_rrule(Δ::Union{Thunk,Fill,Tangent}) - Δf, Δx_el = back(unthunk(Δ).value) - return NoTangent(), Δf, Fill(Δx_el, axes(x)) - end - _map_Fill_rrule(::AbstractZero) = NoTangent(), NoTangent(), NoTangent() - return Fill(y_el, axes(x)), _map_Fill_rrule -end - -# Somehow needed to avoid the _map -> map indirection -function _map(f, xs::Fill...) - all(==(axes(first(xs))), axes.(xs)) || error("All axes should be the same") - Fill(f(FillArrays.getindex_value.(xs)...), axes(first(xs))) -end - -function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(_map), f, xs::Fill...) - z_el, back = ChainRulesCore.rrule_via_ad(config, f, FillArrays.getindex_value.(xs)...) - function _map_Fill_rrule(Δ) - Δf, Δxs_el... = back(unthunk(Δ).value) - return NoTangent(), Δf, Fill.(Δxs_el, axes.(xs))... - end - return Fill(z_el, axes(first(xs))), _map_Fill_rrule -end -### Same thing for `StructArray` - - -function rrule(::typeof(step), x::T) where {T<:StepRangeLen} - function step_StepRangeLen_rrule(Δ) - return NoTangent(), Tangent{T}(step=Δ) - end - return step(x), step_StepRangeLen_rrule -end - -function rrule(::typeof(Base.getindex), x::SVector{1,1}, n::Int) - getindex_SArray_rrule(Δ) = NoTangent(), SVector{1}(Δ), NoTangent() - return x[n], getindex_SArray_rrule -end - -# -# AD-free pullbacks for a few things. These are primitives that will be used to write the -# gradients. -# - -function cholesky_rrule(Σ::Symmetric{<:Real, <:StridedMatrix}) - C = cholesky(Σ) - function cholesky_pullback(Δ::NamedTuple) - U, Ū = C.U, Δ.factors - Σ̄ = Ū * U' - Σ̄ = LinearAlgebra.copytri!(Σ̄, 'U') - Σ̄ = ldiv!(U, Σ̄) - BLAS.trsm!('R', 'U', 'T', 'N', one(eltype(Σ)), U.data, Σ̄) - - for n in diagind(Σ̄) - Σ̄[n] /= 2 - end - return NoTangent(), UpperTriangular(Σ̄) - end - return C, cholesky_pullback -end - -function cholesky_rrule(S::Symmetric{<:Real, <:StaticMatrix{N, N}}) where {N} - C = cholesky(S) - function cholesky_pullback(Δ::Tangent) - U, Ū = C.U, Δ.factors - Σ̄ = SMatrix{N,N}(Symmetric(Ū * U')) - Σ̄ = U \ (U \ Σ̄)' - Σ̄ = Σ̄ - Diagonal(Σ̄) / 2 - return NoTangent(), Tangent{typeof(S)}(data=SMatrix{N, N}(UpperTriangular(Σ̄))) - end - return C, cholesky_pullback -end - -function rrule(::typeof(cholesky), S::Symmetric{<:Real, <:StaticMatrix{N, N}}) where {N} - return cholesky_rrule(S) -end - -function Zygote.accum(a::UpperTriangular, b::UpperTriangular) - return UpperTriangular(Zygote.accum(a.data, b.data)) -end - -Zygote.accum(D::Diagonal{<:Real}, U::UpperTriangular{<:Real}) = UpperTriangular(D + U.data) -Zygote.accum(a::UpperTriangular, b::Diagonal) = Zygote.accum(b, a) - -Zygote._symmetric_back(Δ::UpperTriangular{<:Any, <:SArray}, uplo) = Δ -function Zygote._symmetric_back(Δ::SMatrix{N, N}, uplo) where {N} - if uplo === 'U' - return SMatrix{N, N}(UpperTriangular(Δ) + UpperTriangular(Δ') - Diagonal(Δ)) - else - return SMatrix{N, N}(LowerTriangular(Δ) + LowerTriangular(Δ') - Diagonal(Δ)) - end -end - -# Temporary hacks. - -using Zygote: literal_getproperty, literal_indexed_iterate, literal_getindex - -function Zygote._pullback(::NoContext, ::typeof(*), A::Adjoint, B::AbstractMatrix) - times_pullback(::Nothing) = nothing - times_pullback(Δ) = nothing, Adjoint(B * Δ'), A' * Δ - return A * B, times_pullback -end - -function Zygote._pullback(::NoContext, ::typeof(literal_getproperty), C::Cholesky, ::Val{:U}) - function literal_getproperty_pullback(Δ) - return (nothing, (uplo=nothing, info=nothing, factors=UpperTriangular(Δ))) - end - literal_getproperty_pullback(Δ::Nothing) = nothing - return literal_getproperty(C, Val(:U)), literal_getproperty_pullback -end - -Zygote.accum(x::Adjoint...) = Adjoint(Zygote.accum(map(parent, x)...)) - -Zygote.accum(x::NamedTuple{(:parent,)}, y::Adjoint) = (parent=accum(x.parent, y.parent),) - -function Zygote.accum(A::UpperTriangular{<:Any, <:SMatrix{P}}, B::SMatrix{P, P}) where {P} - return Zygote.accum(SMatrix{P, P}(A), B) -end - -function Zygote.accum(B::SMatrix{P, P}, A::UpperTriangular{<:Any, <:SMatrix{P}}) where {P} - return Zygote.accum(B, SMatrix{P, P}(A)) -end - -function Zygote.accum(a::Tangent{T}, b::NamedTuple) where {T} - return Zygote.accum(a, Tangent{T}(; b...)) -end - -function Base.:(-)( - A::UpperTriangular{<:Real, <:SMatrix{N, N}}, B::Diagonal{<:Real, <:SVector{N}}, -) where {N} - return UpperTriangular(A.data - B) -end - -function _symmetric_back(Δ, uplo) - L, U, D = LowerTriangular(Δ), UpperTriangular(Δ), Diagonal(Δ) - return collect(uplo == Symbol(:U) ? U .+ transpose(L) - D : L .+ transpose(U) - D) -end -_symmetric_back(Δ::Diagonal, uplo) = Δ -_symmetric_back(Δ::UpperTriangular, uplo) = collect(uplo == Symbol('U') ? Δ : transpose(Δ)) -_symmetric_back(Δ::LowerTriangular, uplo) = collect(uplo == Symbol('U') ? transpose(Δ) : Δ) - -function ChainRulesCore.rrule(::Type{Symmetric}, X::StridedMatrix{<:Real}, uplo=:U) - function Symmetric_rrule(Δ) - ΔX = Δ isa AbstractZero ? NoTangent() : _symmetric_back(Δ, uplo) - return NoTangent(), ΔX, NoTangent() - end - return Symmetric(X, uplo), Symmetric_rrule -end - -function rrule(::Type{StructArray}, x::T) where {T<:Union{Tuple,NamedTuple}} - y = StructArray(x) - StructArray_rrule(Δ::Thunk) = StructArray_rrule(unthunk(Δ)) - function StructArray_rrule(Δ) - return NoTangent(), Tangent{T}(StructArrays.components(backing.(Δ))...) - end - function StructArray_rrule(Δ::AbstractArray) - return NoTangent(), Tangent{T}((getproperty.(Δ, p) for p in propertynames(y))...) - end - return y, StructArray_rrule -end -function rrule(::Type{StructArray{X}}, x::T) where {X,T<:Union{Tuple,NamedTuple}} - y = StructArray{X}(x) - function StructArray_rrule(Δ) - return NoTangent(), Tangent{T}(StructArrays.components(backing.(Δ))...) - end - function StructArray_rrule(Δ::Tangent) - return NoTangent(), Tangent{T}(Δ.components...) - end - return y, StructArray_rrule -end - - -# `getproperty` accesses the `components` field of a `StructArray`. This rule makes that -# explicit. -# function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(Base.getproperty), x::StructArray, ::Val{p}, -# ) where {p} -# value, pb = rrule_via_ad(config, Base.getproperty, StructArrays.components(x), Val(p)) -# function getproperty_rrule(Δ) -# return NoTangent(), Tangent{typeof(x)}(components=pb(Δ)[2]), NoTangent() -# end -# return value, getproperty_rrule -# end - -function time_ad(label::String, f, x...) - println("primal: ", label) - return @time f(x...) -end - -time_ad(::Val{:disabled}, label::String, f, x...) = f(x...) - -function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(time_ad), label::String, f, x...) - println("Forward: ", label) - out, pb = @time rrule_via_ad(config, f, x...) - function time_ad_pullback(Δ) - println("Pullback: ", label) - Δinputs = @time pb(Δ) - return (NoTangent(), NoTangent(), NoTangent(), Δinputs...) - end - return out, time_ad_pullback -end - -function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(\), A::Diagonal{<:Real}, x::Vector{<:Real}) - out, pb = rrule_via_ad(config, (a, x) -> a .\ x, diag(A), x) - function ldiv_pullback(Δ) - if Δ isa AbstractZero - return NoTangent() - else - _, Δa, Δx = pb(Δ) - return NoTangent(), Diagonal(Δa), Δx - end - end - return out, ldiv_pullback -end - -function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(\), A::Diagonal{<:Real}, x::Matrix{<:Real}) - out, pb = rrule_via_ad(config, (a, x) -> a .\ x, diag(A), x) - function ldiv_pullback(Δ) - if Δ isa AbstractZero - return NoTangent() - else - _, Δa, Δx = pb(Δ) - return NoTangent(), Diagonal(Δa), Δx - end - end - return out, ldiv_pullback -end - -using Base.Broadcast: broadcasted - -function ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(\), a::Vector{<:Real}, x::Vector{<:Real}) - y = a .\ x - broadcast_ldiv_pullback(::AbstractZero) = NoTangent(), NoTangent(), NoTangent() - broadcast_ldiv_pullback(Δ::AbstractVector{<:Real}) = NoTangent(), NoTangent(), -(Δ .* y ./ a), a .\ Δ - return y, broadcast_ldiv_pullback -end - -function ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(\), a::Vector{<:Real}, x::Matrix{<:Real}) - y = a .\ x - broadcast_ldiv_pullback(::AbstractZero) = NoTangent(), NoTangent(), NoTangent() - broadcast_ldiv_pullback(Δ::AbstractMatrix{<:Real}) = NoTangent(), NoTangent(), -vec(sum(Δ .* y ./ a; dims=2)), a .\ Δ - return y, broadcast_ldiv_pullback -end diff --git a/src/util/gaussian.jl b/src/util/gaussian.jl index 6f5bf00c..5f9667db 100644 --- a/src/util/gaussian.jl +++ b/src/util/gaussian.jl @@ -70,13 +70,6 @@ storage_type(::Gaussian{<:Vector{T}}) where {T<:Real} = ArrayStorage(T) storage_type(::Gaussian{<:SVector{D, T}}) where {D, T<:Real} = SArrayStorage(T) storage_type(::Gaussian{T}) where {T<:Real} = ScalarStorage(T) -function ChainRulesCore.rrule(::Type{<:Gaussian}, m, P) - proj_P = ProjectTo(P) - Gaussian_pullback(::ZeroTangent) = NoTangent(), NoTangent(), NoTangent() - Gaussian_pullback(Δ) = NoTangent(), Δ.m, proj_P(Δ.P) - return Gaussian(m, P), Gaussian_pullback -end - Base.length(x::Gaussian) = 0 # Zero-adjoint initialisation for the benefit of `scan`. diff --git a/src/util/harmonise.jl b/src/util/harmonise.jl deleted file mode 100644 index 5989d890..00000000 --- a/src/util/harmonise.jl +++ /dev/null @@ -1,122 +0,0 @@ -# All of this functionality is utilised only in the AD tests. Can be safely ignored if -# you're concerned with understanding how TemporalGPs works. - -using ChainRulesCore: backing - -# Functionality to test my testing functionality. -are_harmonised(a::Any, b::AbstractZero) = true -are_harmonised(a::AbstractZero, b::Any) = true -are_harmonised(a::AbstractZero, b::AbstractZero) = true - -are_harmonised(a::Number, b::Number) = true - -function are_harmonised(a::AbstractArray, b::AbstractArray) - return all(ab -> are_harmonised(ab...), zip(a, b)) -end - -are_harmonised(a::Tuple, b::Tuple) = all(ab -> are_harmonised(ab...), zip(a, b)) - -function are_harmonised(a::Tangent{<:Any, <:Tuple}, b::Tangent{<:Any, <:Tuple}) - return all(ab -> are_harmonised(ab...), zip(a, b)) -end - -function are_harmonised( - a::Tangent{<:Any, <:NamedTuple}, - b::Tangent{<:Any, <:NamedTuple}, -) - return all( - name -> are_harmonised(getproperty(a, name), getproperty(b, name)), - union(fieldnames(typeof(a)), fieldnames(typeof(b))), - ) -end - -# Functionality to make it possible to compare different kinds of differentials. It's not -# entirely clear how much sense this makes mathematically, but it seems to work in a -# practical sense at the minute. -harmonise(a::Any, b::AbstractZero) = (a, b) -harmonise(a::AbstractZero, b::Any) = (a, b) -harmonise(a::AbstractZero, b::AbstractZero) = (a, b) - -# Resolve ambiguity. -harmonise(a::AbstractZero, b::Tangent{<:Any, <:NamedTuple}) = (a, b) - -harmonise(a::Number, b::Number) = (a, b) - -function harmonise(a::Tuple, b::Tuple) - vals = map(harmonise, a, b) - return first.(vals), last.(vals) -end -function harmonise(a::AbstractArray, b::AbstractArray) - vals = map(harmonise, a, b) - return first.(vals), last.(vals) -end - -function harmonise(a::Adjoint, b::Adjoint) - vals = harmonise(a.parent, b.parent) - return Tangent{Any}(parent=vals[1]), Tangent{Any}(parent=vals[2]) -end - -function harmonise(a::Tangent{<:Any, <:Tuple}, b::Tangent{<:Any, <:Tuple}) - vals = map(harmonise, backing(a), backing(b)) - return (Tangent{Any}(first.(vals)...), Tangent{Any}(last.(vals)...)) -end - -harmonise(a::Tangent{<:Any, <:Tuple}, b::Tuple) = harmonise(a, Tangent{Any}(b...)) - -harmonise(a::Tuple, b::Tangent{<:Any, <:Tuple}) = harmonise(Tangent{Any}(a...), b) - -function harmonise( - a::Tangent{<:Any, <:NamedTuple{names}}, - b::Tangent{<:Any, <:NamedTuple{names}}, -) where {names} - vals = map(harmonise, values(backing(a)), values(backing(b))) - a_harmonised = Tangent{Any}(; NamedTuple{names}(first.(vals))...) - b_harmonised = Tangent{Any}(; NamedTuple{names}(last.(vals))...) - return (a_harmonised, b_harmonised) -end - -function harmonise(a::Tangent{<:Any, <:NamedTuple}, b::Tangent{<:Any, <:NamedTuple}) - - # Compute names missing / present in each data structure. - a_names = propertynames(backing(a)) - b_names = propertynames(backing(b)) - mutual_names = intersect(a_names, b_names) - all_names = (union(a_names, b_names)..., ) - a_missing_names = setdiff(all_names, a_names) - b_missing_names = setdiff(all_names, b_names) - - # Construct `Tangent`s with the same names. - a_vals = map(name -> name ∈ a_names ? getproperty(a, name) : ZeroTangent(), all_names) - b_vals = map(name -> name ∈ b_names ? getproperty(b, name) : ZeroTangent(), all_names) - a_unioned_names = Tangent{Any}(; NamedTuple{all_names}(a_vals)...) - b_unioned_names = Tangent{Any}(; NamedTuple{all_names}(b_vals)...) - - # Harmonise those composites. - return harmonise(a_unioned_names, b_unioned_names) -end - -function harmonise(a::Tangent{<:Any, <:NamedTuple}, b) - b_names = fieldnames(typeof(b)) - vals = map(name -> getfield(b, name), b_names) - return harmonise( - a, Tangent{Any}(; NamedTuple{b_names}(vals)...), - ) -end - -harmonise(x::AbstractMatrix, y::NamedTuple{(:diag,)}) = (diag(x), y.diag) -function harmonise(x::AbstractVector, y::NamedTuple{(:value,:axes)}) - x = reduce(Zygote.accum, x) - (x, y.value) -end - - -harmonise(a::Tangent{<:Any, <:NamedTuple}, b::AbstractZero) = (a, b) - -harmonise(a, b::Tangent{<:Any, <:NamedTuple}) = reverse(harmonise(b, a)) - -# Special-cased handling for `Adjoint`s. Due to our usual AD setup, a differential for an -# Adjoint can be represented either by a matrix or a `Tangent`. Both ought to `to_vec` to -# the same thing though, so this should be fine for now, if a little unsatisfactory. -function harmonise(a::Adjoint, b::Tangent{<:Adjoint, <:NamedTuple}) - return Tangent{Any}(parent=parent(a)), b -end diff --git a/src/util/regular_data.jl b/src/util/regular_data.jl index d9c59ffc..ee962acd 100644 --- a/src/util/regular_data.jl +++ b/src/util/regular_data.jl @@ -2,11 +2,8 @@ RegularSpacing{T<:Real} <: AbstractVector{T} `RegularSpacing(t0, Δt, N)` represents the same thing as `range(t0; step=Δt, length=N)`, but -has a different implementation that makes it possible to differentiate through with the -current version of `Zygote`. This data structure will be entirely removed once it's possible -to work with `StepRangeLen`s in `Zygote`. - -Relevant issue: https://github.com/FluxML/Zygote.jl/issues/550 +has a different implementation which avoids using extended-precision floating point +numbers. This is needed for all AD frameworks. """ struct RegularSpacing{T<:Real} <: AbstractVector{T} t0::T @@ -23,11 +20,3 @@ Base.size(x::RegularSpacing) = (x.N,) Base.getindex(x::RegularSpacing, n::Int) = x.t0 + (n - 1) * x.Δt Base.step(x::RegularSpacing) = x.Δt - -function ChainRulesCore.rrule(::Type{TR}, t0::T, Δt::T, N::Int) where {TR<:RegularSpacing, T<:Real} - function RegularSpacing_rrule(Δ) - Δ = unthunk(Δ) - return NoTangent(), Δ.t0, Δ.Δt, NoTangent() - end - return RegularSpacing(t0, Δt, N), RegularSpacing_rrule -end diff --git a/src/util/scan.jl b/src/util/scan.jl index 8ce67db4..4a42c30b 100644 --- a/src/util/scan.jl +++ b/src/util/scan.jl @@ -27,68 +27,6 @@ function scan_emit(f, xs, state, idx) return (ys, state) end -function rrule(config::RuleConfig, ::typeof(scan_emit), f, xs, init_state, idx) - state = init_state - (y, state) = f(state, _getindex(xs, idx[1])) - - # Heuristic Warning: assume all ys and states have the same type as the 1st. - ys = Vector{typeof(y)}(undef, length(xs)) - states = Vector{typeof(state)}(undef, length(xs)) - - ys[idx[1]] = y - states[idx[1]] = state - - for t in idx[2:end] - (y, state) = f(state, _getindex(xs, t)) - ys[t] = y - states[t] = state - end - - function scan_emit_rrule(Δ) - Δ isa AbstractZero && return ntuple(_->NoTangent(), 5) - Δys = Δ[1] - Δstate = Δ[2] - - # This is a hack to handle the case that Δstate=nothing, and the "look at the - # type of the first thing" heuristic breaks down. - Δstate = Δ[2] isa AbstractZero ? _get_zero_adjoint(states[idx[end]]) : Δ[2] - - T = length(idx) - if T > 1 - _, Δstate, Δx = step_pullback( - config, f, states[idx[T-1]], _getindex(xs, idx[T]), Δys[idx[T]], Δstate, - ) - Δxs = get_adjoint_storage(xs, idx[T], Δx) - for t in reverse(2:(T - 1)) - a = _getindex(xs, idx[t]) - b = Δys[idx[t]] - c = states[idx[t-1]] - _, Δstate, Δx = step_pullback( - config, f, c, a, b, Δstate, - ) - Δxs = _accum_at(Δxs, idx[t], Δx) - end - _, Δstate, Δx = step_pullback( - config, f, init_state, _getindex(xs, idx[1]), Δys[idx[1]], Δstate, - ) - Δxs = _accum_at(Δxs, idx[1], Δx) - return NoTangent(), NoTangent(), Δxs, Δstate, NoTangent() - else - _, Δstate, Δx = step_pullback( - config, f, init_state, _getindex(xs, idx[1]), Δys[idx[1]], Δstate, - ) - Δxs = get_adjoint_storage(xs, idx[1], Δx) - return NoTangent(), NoTangent(), Δxs, Δstate, NoTangent() - end - end - return (ys, state), scan_emit_rrule -end - -@inline function step_pullback(config::RuleConfig, f::Tf, state, x, Δy, Δstate) where {Tf} - _, pb = rrule_via_ad(config, f, state, x) - return pb((Δy, Δstate)) -end - # Helper functionality for constructing appropriate differentials. _getindex(x, idx::Int) = getindex(x, idx) @@ -100,78 +38,3 @@ _getindex(x, idx::Int) = getindex(x, idx) _getindex(x::Base.Iterators.Zip, idx::Int) = __getindex(x.is, idx) __getindex(x::Tuple{Any}, idx::Int) = (_getindex(x[1], idx), ) __getindex(x::Tuple, idx::Int) = (_getindex(x[1], idx), __getindex(Base.tail(x), idx)...) - - -_get_zero_adjoint(::Any) = ZeroTangent() - -# Vector. In all probability, only one of these methods is necessary. - -function get_adjoint_storage(x::Array, n::Int, Δx::T) where {T} - x̄ = Array{T}(undef, size(x)) - x̄[n] = Δx - return x̄ -end - -@inline function _accum_at(Δxs::Vector{T}, n::Int, Δx::T) where {T} - Δxs[n] = Δx - return Δxs -end - -@inline function _accum_at(Δxs::Vector{T}, n::Int, Δx::AbstractMatrix) where {T<:AbstractMatrix} - Δxs[n] = convert(T, Δx) - return Δxs -end - -# If there's nothing, there's nothing to do. -_accum_at(::AbstractZero, ::Int, ::AbstractZero) = NoTangent() - -# Zip -function get_adjoint_storage(x::Base.Iterators.Zip, n::Int, Δx::Tangent) - return (is=map((x_, Δx_) -> get_adjoint_storage(x_, n, Δx_), x.is, backing(Δx)),) -end - -# This is a work-around for `map` not inferring for some unknown reason. Very odd... -function _accum_at(Δxs::NamedTuple{(:is, )}, n::Int, Δx::Tangent) - return (is=__accum_at(Δxs.is, n, backing(Δx)), ) -end -__accum_at(Δxs::Tuple{Any}, n::Int, Δx::Tuple{Any}) = (_accum_at(Δxs[1], n, Δx[1]), ) -function __accum_at(Δxs::Tuple, n::Int, Δx::Tuple) - return (_accum_at(Δxs[1], n, Δx[1]), __accum_at(Base.tail(Δxs), n, Base.tail(Δx))...) -end -# Fill - -get_adjoint_storage(::Fill, ::Int, init) = (value=init, axes=NoTangent()) - -# T is not parametrized since T can be SMatrix and Δx isa SizedMatrix -@inline function _accum_at( - Δxs::NamedTuple{(:value, :axes)}, ::Int, Δx, -) - return (value=Zygote.accum(Δxs.value, Δx), axes=NoTangent()) -end - - - -# StructArray - -function get_adjoint_storage(x::StructArray, n::Int, Δx::Tangent) - init_arrays = map( - (x_, Δx_) -> get_adjoint_storage(x_, n, Δx_), getfield(x, :components), ChainRulesCore.backing(Δx), - ) - return (components = init_arrays, ) -end - -function get_adjoint_storage(x::StructArray, n::Int, Δx::StaticVector) - init_arrays = map( - (x_, Δx_) -> get_adjoint_storage(x_, n, Δx_), getfield(x, :components), Δx, - ) - return (components = init_arrays, ) -end - -# _accum_at for StructArrayget_adjoint_storage(xs, idx[T], Δx) -function _accum_at(Δxs::NamedTuple{(:components,)}, n::Int, Δx::Tangent) - return (components = map((Δy, y) -> _accum_at(Δy, n, y), Δxs.components, backing(Δx)), ) -end - -function _accum_at(Δxs::NamedTuple{(:components,)}, n::Int, Δx::SVector) - return (components = map((Δy, y) -> _accum_at(Δy, n, y), Δxs.components, backing(Δx)), ) -end diff --git a/src/util/zygote_friendly_map.jl b/src/util/zygote_friendly_map.jl index ab0eba29..a9440942 100644 --- a/src/util/zygote_friendly_map.jl +++ b/src/util/zygote_friendly_map.jl @@ -30,45 +30,6 @@ function dense_zygote_friendly_map(f::Tf, x) where {Tf} return ys end -function ChainRulesCore.rrule(::typeof(dense_zygote_friendly_map), f::Tf, x) where {Tf} - - # Perform first iteration. - y_1, pb_1 = rrule_via_ad(Zygote.ZygoteRuleConfig(NoContext()), f, _getindex(x, 1)) - - # Allocate for outputs. - ys = Array{typeof(y_1)}(undef, size(x)) - ys[1] = y_1 - - # Allocate for pullbacks. - pbs = Array{typeof(pb_1)}(undef, size(x)) - pbs[1] = pb_1 - - for n in 2:length(x) - y, pb = rrule_via_ad(Zygote.ZygoteRuleConfig(NoContext()), f, _getindex(x, n)) - ys[n] = y - pbs[n] = pb - end - - function zygote_friendly_map_pullback(Δ) - Δ isa AbstractZero && return NoTangent(), NoTangent(), NoTangent() - - # Do first iteration. - Δx_1 = pbs[1](Δ[1]) - - # Allocate for cotangents. - Δxs = get_adjoint_storage(x, 1, Δx_1[2]) - - for n in 2:length(x) - Δx = pbs[n](Δ[n]) - Δxs = _accum_at(Δxs, n, Δx[2]) - end - - return NoTangent(), NoTangent(), Δxs - end - - return ys, zygote_friendly_map_pullback -end - zygote_friendly_map(f, x::Fill) = map(f, x) function zygote_friendly_map( diff --git a/test/Project.toml b/test/Project.toml deleted file mode 100644 index d3eabb40..00000000 --- a/test/Project.toml +++ /dev/null @@ -1,29 +0,0 @@ -[deps] -AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918" -BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" -BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0" -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" -FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" -FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" -KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392" -LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" -StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" - -[compat] -AbstractGPs = "0.5" -BenchmarkTools = "0.5" -BlockDiagonals = "0.1" -ChainRulesCore = "1" -ChainRulesTestUtils = "1.10" -FillArrays = "0.13.0 - 0.13.7" -FiniteDifferences = "0.12" -KernelFunctions = "0.10" -StaticArrays = "1" -StructArrays = "0.6" -Zygote = "0.6" diff --git a/test/front_matter.jl b/test/front_matter.jl new file mode 100644 index 00000000..313baa40 --- /dev/null +++ b/test/front_matter.jl @@ -0,0 +1,39 @@ + +using AbstractGPs, + BlockDiagonals, + FillArrays, + LinearAlgebra, + KernelFunctions, + Mooncake, + Random, + StaticArrays, + StructArrays, + TemporalGPs, + Test + +using AbstractGPs: var +using Mooncake.TestUtils: test_rule +using TemporalGPs: + AbstractLGSSM, + _filter, + Gaussian, + x0, + fill_in_missings, + replace_observation_noise_cov, + scan_emit, + transform_model_and_obs + +ENV["TESTING"] = "TRUE" + +# GROUP is an env variable from CI which can take the following values +# ["test util", "test models" "test models-lgssm" "test gp" "test space_time"] +# Select any of this to test a particular aspect. +# To test everything, simply set GROUP to "all" +# ENV["GROUP"] = "test gp" +const GROUP = get(ENV, "GROUP", "all") + +const TEST_TYPE_INFER = false # Test type stability over the tests +const TEST_ALLOC = false # Test allocations over the tests + +include("test_util.jl") +include(joinpath("models", "model_test_utils.jl")) diff --git a/test/gp/lti_sde.jl b/test/gp/lti_sde.jl index a4f20578..4d07965e 100644 --- a/test/gp/lti_sde.jl +++ b/test/gp/lti_sde.jl @@ -1,6 +1,5 @@ using KernelFunctions using KernelFunctions: kappa -using ChainRulesTestUtils using TemporalGPs: build_lgssm, StorageType, is_of_storage_type, lgssm_components using Test @@ -14,47 +13,50 @@ function _construction_tester(f_naive::GP, storage::StorageType, σ², t::Abstra return build_lgssm(fx) end -@testset "ApproxPeriodicKernel" begin - k = ApproxPeriodicKernel() - @test k isa ApproxPeriodicKernel{7} - # Test that it behaves like a normal PeriodicKernel - k_base = PeriodicKernel() - x = rand() - @test kappa(k, x) == kappa(k_base, x) - x = rand(3) - @test kernelmatrix(k, x) ≈ kernelmatrix(k_base, x) - # Test dimensionality of LGSSM components - Nt = 10 - @testset "$(typeof(t)), $storage, $N" for t in ( - sort(rand(Nt)), RegularSpacing(0.0, 0.1, Nt) - ), - storage in (ArrayStorage{Float64}(), ), - N in (5, 8) - - k = ApproxPeriodicKernel{N}() - As, as, Qs, emission_projections, x0 = lgssm_components(k, t, storage) - @test length(As) == Nt - @test all(x -> size(x) == (N * 2, N * 2), As) - @test length(as) == Nt - @test all(x -> size(x) == (N * 2,), as) - @test length(Qs) == Nt - @test all(x -> size(x) == (N * 2, N * 2), Qs) - end -end - println("lti_sde:") @testset "lti_sde" begin + + @testset "ApproxPeriodicKernel" begin + k = ApproxPeriodicKernel() + @test k isa ApproxPeriodicKernel{7} + # Test that it behaves like a normal PeriodicKernel + k_base = PeriodicKernel() + x = rand() + @test kappa(k, x) == kappa(k_base, x) + x = rand(3) + @test kernelmatrix(k, x) ≈ kernelmatrix(k_base, x) + # Test dimensionality of LGSSM components + Nt = 10 + @testset "$(typeof(t)), $storage, $N" for t in ( + sort(rand(Nt)), RegularSpacing(0.0, 0.1, Nt) + ), + storage in (ArrayStorage{Float64}(), ), + N in (5, 8) + + k = ApproxPeriodicKernel{N}() + As, as, Qs, emission_projections, x0 = lgssm_components(k, t, storage) + @test length(As) == Nt + @test all(x -> size(x) == (N * 2, N * 2), As) + @test length(as) == Nt + @test all(x -> size(x) == (N * 2,), as) + @test length(Qs) == Nt + @test all(x -> size(x) == (N * 2, N * 2), Qs) + end + end + @testset "block_diagonal" begin + rng = Xoshiro(123) A = randn(2, 2) B = randn(3, 3) C = randn(5, 5) - test_rrule(TemporalGPs.block_diagonal, A, B, C; check_inferred=false) - test_rrule( + test_rule(rng, TemporalGPs.block_diagonal, A, B, C; is_primitive=false) + test_rule( + rng, TemporalGPs.block_diagonal, SMatrix{2,2}(A), SMatrix{3,3}(B), SMatrix{5,5}(C); - check_inferred=false, + is_primitive=false, ) end @@ -209,47 +211,10 @@ println("lti_sde:") @test logpdf(fx, y) ≈ logpdf(fx_naive, y) end - @testset "check args to_vec properly" begin - k_vec, k_from_vec = to_vec(kernel.val) - @test typeof(k_from_vec(k_vec)) == typeof(kernel.val) - - storage_vec, storage_from_vec = to_vec(storage.val) - @test typeof(storage_from_vec(storage_vec)) == typeof(storage.val) - - σ²_vec, σ²_from_vec = to_vec(σ².val) - @test typeof(σ²_from_vec(σ²_vec)) == typeof(σ².val) - - t_vec, t_from_vec = to_vec(t.val) - @test typeof(t_from_vec(t_vec)) == typeof(t.val) - end - - # Just need to ensure we can differentiate through construction properly. - if isnothing(kernel.to_vec_grad) - @test_broken false # "Gradient tests are not passing" - continue - elseif kernel.to_vec_grad - test_zygote_grad_finite_differences_compatible( - _construction_tester, - f_naive, - storage.val, - σ².val, - t.val; - check_inferred=false, - rtol=1e-6, - atol=1e-6, - ) - else - test_zygote_grad( - _construction_tester, - f_naive, - storage.val, - σ².val, - t.val; - check_inferred=false, - rtol=1e-6, - atol=1e-6, - ) - end + test_rule( + rng, _construction_tester, f_naive, storage.val, σ².val, t.val; + is_primitive=false, interface_only=true, + ) end end end diff --git a/test/models/lgssm.jl b/test/models/lgssm.jl index 19662b66..891101d8 100644 --- a/test/models/lgssm.jl +++ b/test/models/lgssm.jl @@ -16,14 +16,7 @@ using TemporalGPs: ScalarOutputLGC, Forward, Reverse, - ordering, - NoContext -using KernelFunctions -using Test -using Random: MersenneTwister -using LinearAlgebra -using StructArrays -using Zygote, StaticArrays + ordering println("lgssm:") @testset "lgssm" begin @@ -91,42 +84,28 @@ println("lgssm:") @testset "step_marginals" begin @inferred step_marginals(x, model[1]) - adjoint_test(step_marginals, (x, model[1])) - if storage.val isa SArrayStorage && TEST_ALLOC - check_adjoint_allocations(step_marginals, (x, model[1])) - end + test_rule(rng, step_marginals, x, model[1]; is_primitive=false) + end @testset "step_logpdf" begin args = (ordering(model[1]), x, (model[1], y)) @inferred step_logpdf(args...) - adjoint_test(step_logpdf, args) - if storage.val isa SArrayStorage && TEST_ALLOC - check_adjoint_allocations(step_logpdf, args) - end + test_rule(rng, step_logpdf, args...; is_primitive=false) end @testset "step_filter" begin args = (ordering(model[1]), x, (model[1], y)) @inferred step_filter(args...) - adjoint_test(step_filter, args) - if storage.val isa SArrayStorage && TEST_ALLOC - check_adjoint_allocations(step_filter, args) - end + test_rule(rng, step_filter, args...; is_primitive=false) end @testset "invert_dynamics" begin args = (x, x, model[1].transition) @inferred invert_dynamics(args...) - adjoint_test(invert_dynamics, args) - if storage.val isa SArrayStorage && TEST_ALLOC - check_adjoint_allocations(invert_dynamics, args) - end + test_rule(rng, invert_dynamics, args...; is_primitive=false) end @testset "step_posterior" begin args = (ordering(model[1]), x, (model[1], y)) @inferred step_posterior(args...) - adjoint_test(step_posterior, args) - if storage.val isa SArrayStorage && TEST_ALLOC - check_adjoint_allocations(step_posterior, args) - end + test_rule(rng, step_posterior, args...; is_primitive=false) end # Run standard battery of LGSSM tests. @@ -134,7 +113,6 @@ println("lgssm:") rng, model; rtol=1e-5, atol=1e-5, - context=NoContext(), max_primal_allocs=25, max_forward_allocs=25, max_backward_allocs=25, diff --git a/test/models/linear_gaussian_conditionals.jl b/test/models/linear_gaussian_conditionals.jl index 9e0e7ba7..78be0a39 100644 --- a/test/models/linear_gaussian_conditionals.jl +++ b/test/models/linear_gaussian_conditionals.jl @@ -104,8 +104,7 @@ println("linear_gaussian_conditionals:") # Check that everything infers and AD gives the right answer. @inferred posterior_and_lml(x, model, y_missing) - x̄ = adjoint_test(posterior_and_lml, (x, model, y_missing)) - @test x̄[2].Q isa NamedTuple{(:diag, )} + test_rule(rng, posterior_and_lml, x, model, y_missing; is_primitive=false) end end @@ -204,8 +203,7 @@ println("linear_gaussian_conditionals:") # Check that everything infers and AD gives the right answer. @inferred posterior_and_lml(x, model, y_missing) - x̄ = adjoint_test(posterior_and_lml, (x, model, y_missing)) - @test x̄[2].fan_out.Q isa NamedTuple{(:diag, )} + test_rule(rng, posterior_and_lml, x, model, y_missing; is_primitive=false) end end end diff --git a/test/models/missings.jl b/test/models/missings.jl index 3b4084e0..fb78be70 100644 --- a/test/models/missings.jl +++ b/test/models/missings.jl @@ -1,13 +1,4 @@ -using TemporalGPs: - x0, - fill_in_missings, - replace_observation_noise_cov, - transform_model_and_obs -using Random: randperm -using ChainRulesTestUtils -using Zygote: Context - -@info "missings:" +@info "missings" @testset "missings" begin rng = MersenneTwister(123456) @@ -125,10 +116,15 @@ using Zygote: Context # Only test the bits of AD that we haven't tested before. @testset "AD: transform_model_and_obs" begin - fdm = central_fdm(2, 1) - adjoint_test(fill_in_missings, (model.emissions.Q, y_missing); fdm=fdm) - adjoint_test(replace_observation_noise_cov, (model, model.emissions.Q)) - adjoint_test(transform_model_and_obs, (model, y_missing); fdm=fdm) + test_rule(rng, fill_in_missings, model.emissions.Q, y_missing; is_primitive=false) + test_rule( + rng, replace_observation_noise_cov, model, model.emissions.Q; + is_primitive=false, interface_only=true, + ) + test_rule( + rng, transform_model_and_obs, model, y_missing; + is_primitive=false, interface_only=true, + ) end end @@ -177,9 +173,8 @@ using Zygote: Context return yn_missing end - # Check logpdf and inference run, infer, and play nicely with AD. + # Check logpdf and inference run, infer. @inferred logpdf(model, y_missing) - test_zygote_grad_finite_differences_compatible(y -> logpdf(model, y) ⊢ NoTangent(), y_missing) @inferred posterior(model, y_missing) end end; diff --git a/test/models/model_test_utils.jl b/test/models/model_test_utils.jl index 5d7f3e46..34014b81 100644 --- a/test/models/model_test_utils.jl +++ b/test/models/model_test_utils.jl @@ -1,6 +1,3 @@ -using ChainRulesTestUtils: ChainRulesTestUtils, rand_tangent -using FillArrays -using Random: AbstractRNG using TemporalGPs: ArrayStorage, SArrayStorage, @@ -19,9 +16,6 @@ using TemporalGPs: LargeOutputLGC, BottleneckLGC - - - # Generation of positive semi-definite matrices. function random_vector(rng::AbstractRNG, N::Int, ::ArrayStorage{T}) where {T<:Real} @@ -92,15 +86,6 @@ function random_gaussian(rng::AbstractRNG, dim::Int, s::StorageType) return Gaussian(random_vector(rng, dim, s), random_nice_psd_matrix(rng, dim, s)) end -function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, d::T) where {T<:Gaussian} - return Tangent{T}( - m=rand_tangent(rng, d.m), - P=random_nice_psd_matrix(rng, length(d.m), storage_type(d)), - ) -end - - - # Generation of SmallOutputLGC. function random_small_output_lgc( @@ -185,16 +170,6 @@ function random_ti_gmm(rng::AbstractRNG, ordering, Dlat::Int, N::Int, s::Storage return GaussMarkovModel(ordering, As, as, Qs, x0) end -function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, gmm::T) where {T<:GaussMarkovModel} - return Tangent{T}( - ordering = nothing, - As = rand_tangent(rng, gmm.As), - as = rand_tangent(rng, gmm.as), - Qs = gmm_Qs_tangent(rng, gmm.Qs, storage_type(gmm)), - x0 = rand_tangent(rng, gmm.x0), - ) -end - function gmm_Qs_tangent( rng::AbstractRNG, Qs::T, storage_type::StorageType, ) where {T<:Vector{<:AbstractMatrix}} @@ -302,20 +277,6 @@ function random_lgssm( return LGSSM(transitions, emissions) end -function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, ssm::T) where {T<:LGSSM} - Hs = ssm.emissions.A - hs = ssm.emissions.a - Σs = ssm.emissions.Q - return Tangent{T}( - transitions = rand_tangent(rng, ssm.transitions), - emissions = Tangent{typeof(ssm.emissions)}(components=( - A=rand_tangent(rng, Hs), - a=rand_tangent(rng, hs), - Q=gmm_Qs_tangent(rng, Σs, storage_type(ssm)), - )), - ) -end - # function random_tv_scalar_lgssm(rng::AbstractRNG, Dlat::Int, N::Int, storage) # return ScalarLGSSM(random_tv_lgssm(rng, Dlat, 1, N, storage)) # end diff --git a/test/runtests.jl b/test/runtests.jl index ce79abe6..7e3bfb53 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,101 +1,46 @@ -using Test - -ENV["TESTING"] = "TRUE" - -# GROUP is an env variable from CI which can take the following values -# ["test util", "test models" "test models-lgssm" "test gp" "test space_time"] -# Select any of this to test a particular aspect. -# To test everything, simply set GROUP to "all" -# ENV["GROUP"] = "test gp" -const GROUP = get(ENV, "GROUP", "all") -OUTER_GROUP = first(split(GROUP, ' ')) - -const TEST_TYPE_INFER = false # Test type stability over the tests -const TEST_ALLOC = false # Test allocations over the tests +include("front_matter.jl") # Run the tests. -if OUTER_GROUP == "test" || OUTER_GROUP == "all" - - # Determines which group of tests should be run. - group_info = split(GROUP, ' ') - TEST_GROUP = length(group_info) == 1 ? "all" : group_info[2] - - using AbstractGPs - using BlockDiagonals - using ChainRulesCore - using ChainRulesTestUtils - using FillArrays - using FiniteDifferences - using LinearAlgebra - using KernelFunctions - using Random - using StaticArrays - using StructArrays - using TemporalGPs - - using Zygote - - using AbstractGPs: var - using TemporalGPs: AbstractLGSSM, _filter, NoContext - using Zygote: Context, _pullback - - include("test_util.jl") - - @show TEST_GROUP GROUP - - @testset "TemporalGPs.jl" begin - - if TEST_GROUP == "util" || GROUP == "all" - println("util:") - @testset "util" begin - include(joinpath("util", "harmonise.jl")) - include(joinpath("util", "scan.jl")) - include(joinpath("util", "zygote_friendly_map.jl")) - include(joinpath("util", "chainrules.jl")) - include(joinpath("util", "gaussian.jl")) - include(joinpath("util", "mul.jl")) - include(joinpath("util", "regular_data.jl")) - end - end - - if TEST_GROUP ∈ ["models", "models-lgssm", "gp", "space_time"] || GROUP == "all" - include(joinpath("models", "model_test_utils.jl")) - include(joinpath("models", "test_model_test_utils.jl")) +@testset "TemporalGPs.jl" begin + + if GROUP == "test util" + println("util:") + @testset "util" begin + include(joinpath("util", "scan.jl")) + include(joinpath("util", "zygote_friendly_map.jl")) + include(joinpath("util", "gaussian.jl")) + include(joinpath("util", "mul.jl")) + include(joinpath("util", "regular_data.jl")) end + end - if TEST_GROUP == "models" || GROUP == "all" + if GROUP == "test models" begin + @testset "models" begin println("models:") - @testset "models" begin - include(joinpath("models", "linear_gaussian_conditionals.jl")) - include(joinpath("models", "gauss_markov_model.jl")) - include(joinpath("models", "missings.jl")) - end - end - - if TEST_GROUP == "models-lgssm" || GROUP == "all" - println("models (lgssm):") - @testset "models (lgssm)" begin - include(joinpath("models", "lgssm.jl")) - end + include(joinpath("models", "test_model_test_utils.jl")) + include(joinpath("models", "linear_gaussian_conditionals.jl")) + include(joinpath("models", "gauss_markov_model.jl")) + include(joinpath("models", "lgssm.jl")) + include(joinpath("models", "missings.jl")) end + end - if TEST_GROUP == "gp" || GROUP == "all" - println("gp:") - @testset "gp" begin - include(joinpath("gp", "lti_sde.jl")) - include(joinpath("gp", "posterior_lti_sde.jl")) - end + if GROUP == "test gp" begin + println("gp:") + @testset "gp" begin + include(joinpath("gp", "lti_sde.jl")) + include(joinpath("gp", "posterior_lti_sde.jl")) end + end - if TEST_GROUP == "space_time" || GROUP == "all" - println("space_time:") - @testset "space_time" begin - include(joinpath("space_time", "rectilinear_grid.jl")) - include(joinpath("space_time", "regular_in_time.jl")) - include(joinpath("space_time", "separable_kernel.jl")) - include(joinpath("space_time", "to_gauss_markov.jl")) - include(joinpath("space_time", "pseudo_point.jl")) - end + if GROUP == "test space_time" begin + println("space_time:") + @testset "space_time" begin + include(joinpath("space_time", "rectilinear_grid.jl")) + include(joinpath("space_time", "regular_in_time.jl")) + include(joinpath("space_time", "separable_kernel.jl")) + include(joinpath("space_time", "to_gauss_markov.jl")) + include(joinpath("space_time", "pseudo_point.jl")) end end end diff --git a/test/space_time/pseudo_point.jl b/test/space_time/pseudo_point.jl index e90f0374..ac477976 100644 --- a/test/space_time/pseudo_point.jl +++ b/test/space_time/pseudo_point.jl @@ -101,8 +101,7 @@ using Test elbo_naive = elbo(VFE(f_naive(z_naive)), fx_naive, y) elbo_sde = elbo(fx, y, z_r) @test elbo_naive ≈ elbo_sde rtol=1e-6 - - test_zygote_grad_finite_differences_compatible((y, z_r) -> elbo(fx, y, z_r), y, z_r) + test_rule(rng, elbo, fx, y, z_r; is_primitive=false) # Compute approximate posterior marginals naively. f_approx_post_naive = posterior(VFE(f_naive(z_naive)), fx_naive, y) diff --git a/test/space_time/rectilinear_grid.jl b/test/space_time/rectilinear_grid.jl index fd21e76d..ba82fe63 100644 --- a/test/space_time/rectilinear_grid.jl +++ b/test/space_time/rectilinear_grid.jl @@ -1,15 +1,3 @@ -using Random -using TemporalGPs: RectilinearGrid, SpaceTimeGrid - -function FiniteDifferences.to_vec(x::RectilinearGrid) - v, tup_from_vec = to_vec((x.xl, x.xr)) - function RectilinearGrid_from_vec(v) - tup = tup_from_vec(v) - return RectilinearGrid(tup[1], tup[2]) - end - return v, RectilinearGrid_from_vec -end - @testset "rectilinear_grid" begin rng = MersenneTwister(123456) Nl = 5 diff --git a/test/space_time/regular_in_time.jl b/test/space_time/regular_in_time.jl index 2b98fe01..68e27311 100644 --- a/test/space_time/regular_in_time.jl +++ b/test/space_time/regular_in_time.jl @@ -1,5 +1,3 @@ -using TemporalGPs: RegularInTime - @testset "regular_in_time" begin T = 11 Nts = [rand(1:4) for _ in 1:T] diff --git a/test/space_time/separable_kernel.jl b/test/space_time/separable_kernel.jl index 51e72779..f4b827fd 100644 --- a/test/space_time/separable_kernel.jl +++ b/test/space_time/separable_kernel.jl @@ -1,6 +1,3 @@ -using Random -using TemporalGPs: RectilinearGrid, Separable - @testset "separable_kernel" begin rng = MersenneTwister(123456) diff --git a/test/space_time/to_gauss_markov.jl b/test/space_time/to_gauss_markov.jl index 002bb9d5..18de0ee6 100644 --- a/test/space_time/to_gauss_markov.jl +++ b/test/space_time/to_gauss_markov.jl @@ -1,5 +1,3 @@ -using TemporalGPs: RectilinearGrid, Separable, is_of_storage_type - @testset "to_gauss_markov" begin rng = MersenneTwister(123456) Nr = 3 @@ -7,13 +5,12 @@ using TemporalGPs: RectilinearGrid, Separable, is_of_storage_type Nt_pr = 2 @testset "restructure" begin - adjoint_test( - x -> TemporalGPs.restructure(x, [26, 24, 20, 30]), (randn(100), ); - check_inferred=false, + test_rule( + rng, TemporalGPs.restructure, randn(100), [26, 24, 20, 30]; is_primitive=false ) - adjoint_test( - x -> TemporalGPs.restructure(x, [26, 24, 20, 30]), (Fill(randn(), 100), ); - check_inferred=false, + test_rule( + rng, TemporalGPs.restructure, Fill(randn(), 100), [26, 24, 20, 30]; + is_primitive=false, ) end @@ -89,33 +86,16 @@ using TemporalGPs: RectilinearGrid, Separable, is_of_storage_type # No statistical tests run on `rand`, which seems somewhat dangerous, but there's # not a lot to be done about it unfortunately. @testset "rand" begin - y = rand(rng, fx_post_sde) - @test y isa AbstractVector{<:Real} - @test length(y) == length(x_pr) + _y = rand(rng, fx_post_sde) + @test _y isa AbstractVector{<:Real} + @test length(_y) == length(x_pr) end end - # # I'm not checking correctness here, just that it runs. No custom adjoints have been - # # written that are involved in this that aren't tested, so there should be no need - # # to check correctness. - # @testset "logpdf AD" begin - # out, pb = Zygote._pullback(NoContext(), logpdf, ft_sde, y) - # pb(rand_zygote_tangent(out)) - # end - # # adjoint_test(logpdf, (ft_sde, y); fdm=central_fdm(2, 1), check_inferred=false) - - # if t.val isa RegularSpacing - # adjoint_test( - # (r, Δt, y) -> begin - # x = RectilinearGrid(r, RegularSpacing(t.val.t0, Δt, Nt)) - # _f = to_sde(GP(k.val, GPC())) - # _ft = _f(x, σ².val...) - # return logpdf(_ft, y) - # end, - # (r, t.val.Δt, y_sde); - # check_inferred=false, - # ) - # end + # I'm not checking correctness here, just that it runs. No custom adjoints have been + # written that are involved in this that aren't tested, so there should be no need + # to check correctness. + test_rule(rng, logpdf, ft_sde, y; is_primitive=false, interface_only=true) end end diff --git a/test/test_util.jl b/test/test_util.jl index 083f930a..f5892efb 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -1,19 +1,7 @@ -using AbstractGPs -using BlockDiagonals -using ChainRulesCore: backing, ZeroTangent, NoTangent, Tangent -using ChainRulesTestUtils: ChainRulesTestUtils, test_approx, rand_tangent, test_rrule, ⊢, @ignore_derivatives -using FiniteDifferences -using FillArrays -using LinearAlgebra -using Random: AbstractRNG, MersenneTwister -using StaticArrays -using StructArrays -using TemporalGPs using TemporalGPs: AbstractLGSSM, ElementOfLGSSM, Gaussian, - harmonise, Forward, Reverse, GaussMarkovModel, @@ -30,486 +18,26 @@ using TemporalGPs: x0, scan_emit, ε_randn -using Test -using Zygote -using Zygote: Context - - - -# Make FiniteDifferences work with some of the types in this package. Shame this isn't -# automated... - -import FiniteDifferences: to_vec - -test_zygote_grad(f, args...; check_inferred=false, kwargs...) = test_rrule(Zygote.ZygoteRuleConfig(), f, args...; rrule_f=rrule_via_ad, check_inferred, kwargs...) - -function test_zygote_grad_finite_differences_compatible(f, args...; kwargs...) - x_vec, from_vec = to_vec(args) - function finite_diff_compatible_f(x::AbstractVector) - return @ignore_derivatives(f)(@ignore_derivatives(from_vec)(x)...) - end - test_zygote_grad(finite_diff_compatible_f ⊢ NoTangent(), x_vec; testset_name="test_rrule: $(f) on $(typeof.(args))", kwargs...) -end - -function to_vec(x::Fill) - x_vec, back_vec = to_vec(FillArrays.getindex_value(x)) - function Fill_from_vec(x_vec) - return Fill(back_vec(x_vec), axes(x)) - end - return x_vec, Fill_from_vec -end - -function to_vec(x::Union{Zeros, Ones}) - return Vector{eltype(x)}(undef, 0), _ -> x -end - -# I'M OVERRIDING FINITEDIFFERENCES DEFINITION HERE. THIS IS BAD. -function to_vec(x::Diagonal) - v, diag_from_vec = to_vec(x.diag) - Diagonal_from_vec(v) = Diagonal(diag_from_vec(v)) - return v, Diagonal_from_vec -end - -# function to_vec(x::T) where {T<:NamedTuple} -# isempty(fieldnames(T)) && throw(error("Expected some fields. None found.")) -# vecs_and_backs = map(name->to_vec(getfield(x, name)), fieldnames(T)) -# vecs, backs = first.(vecs_and_backs), last.(vecs_and_backs) -# x_vec, back = to_vec(vecs) -# function namedtuple_to_vec(x′_vec) -# vecs′ = back(x′_vec) -# x′s = map((back, vec)->back(vec), backs, vecs′) -# return (; zip(fieldnames(T), x′s)...) -# end -# return x_vec, namedtuple_to_vec -# end - -function to_vec(x::T) where {T<:StaticArray} - x_dense = collect(x) - x_vec, back_vec = to_vec(x_dense) - function StaticArray_to_vec(x_vec) - return T(back_vec(x_vec)) - end - return x_vec, StaticArray_to_vec -end - -function to_vec(x::Adjoint{<:Any, T}) where {T<:StaticVector} - x_vec, back = to_vec(Matrix(x)) - Adjoint_from_vec(x_vec) = Adjoint(T(conj!(vec(back(x_vec))))) - return x_vec, Adjoint_from_vec -end - -function to_vec(::Tuple{}) - empty_tuple_from_vec(::AbstractVector) = () - return Bool[], empty_tuple_from_vec -end - -function to_vec(x::StructArray{T}) where {T} - x_vec, x_fields_from_vec = to_vec(StructArrays.components(x)) - function StructArray_from_vec(x_vec) - x_field_vecs = x_fields_from_vec(x_vec) - return StructArray{T}(Tuple(x_field_vecs)) - end - return x_vec, StructArray_from_vec -end - -function to_vec(x::TemporalGPs.LGSSM) - x_vec, from_vec = to_vec((x.transitions, x.emissions)) - function LGSSM_from_vec(x_vec) - (transition, emission) = from_vec(x_vec) - return LGSSM(transition, emission) - end - return x_vec, LGSSM_from_vec -end - -function to_vec(x::ElementOfLGSSM) - x_vec, from_vec = to_vec((x.transition, x.emission)) - function ElementOfLGSSM_from_vec(x_vec) - (transition, emission) = from_vec(x_vec) - return ElementOfLGSSM(x.ordering, transition, emission) - end - return x_vec, ElementOfLGSSM_from_vec -end - -function ChainRulesTestUtils.test_approx(actual::Tangent{<:Fill}, expected, msg=""; kwargs...) - test_approx(actual.value, expected.value, msg; kwargs...) -end - -function to_vec(x::PeriodicKernel) - x, to_r = to_vec(x.r) - function PeriodicKernel_from_vec(x) - return PeriodicKernel(;r=exp.(to_r(x))) - end - log.(x), PeriodicKernel_from_vec -end - -to_vec(x::T) where {T} = generic_struct_to_vec(x) - -# This is a copy from FiniteDifferences.jl without the try catch -function generic_struct_to_vec(x::T) where {T} - Base.isstructtype(T) || throw(error("Expected a struct type")) - isempty(fieldnames(T)) && return (Bool[], _ -> x) # Singleton types - val_vecs_and_backs = map(name -> to_vec(getfield(x, name)), fieldnames(T)) - vals = first.(val_vecs_and_backs) - backs = last.(val_vecs_and_backs) - v, vals_from_vec = to_vec(vals) - function structtype_from_vec(v::Vector{<:Real}) - val_vecs = vals_from_vec(v) - vals = map((b, v) -> b(v), backs, val_vecs) - return T(vals...) - end - return v, structtype_from_vec -end - -to_vec(x::TemporalGPs.RectilinearGrid) = generic_struct_to_vec(x) - -function to_vec(x::AbstractRNG) - return Bool[], _ -> x -end - -Base.zero(x::AbstractRNG) = x - -function to_vec(f::GP) - gp_vec, t_from_vec = to_vec((f.mean, f.kernel)) - function GP_from_vec(v) - m, k = t_from_vec(v) - return GP(m, k) - end - return gp_vec, GP_from_vec -end - -function to_vec(k::ConstantKernel) - c, c_to_vec = to_vec(k.c) - function ConstantKernel_from_vec(c) - return ConstantKernel(c=first(c_to_vec(c))) - end - c, ConstantKernel_from_vec -end - -Base.zero(x::AbstractGPs.ZeroMean) = x -Base.zero(x::Kernel) = x -Base.zero(x::TemporalGPs.LTISDE) = x -Base.zero(x::GP) = x -Base.zero(x::AbstractGPs.MeanFunction) = x - -function to_vec(X::BlockDiagonal) - Xs = blocks(X) - Xs_vec, Xs_from_vec = to_vec(Xs) - - function BlockDiagonal_from_vec(Xs_vec) - Xs = Xs_from_vec(Xs_vec) - return BlockDiagonal(Xs) - end - - return Xs_vec, BlockDiagonal_from_vec -end - -function to_vec(x::RegularSpacing) - RegularSpacing_from_vec(v) = RegularSpacing(v[1], v[2], x.N) - return [x.t0, x.Δt], RegularSpacing_from_vec -end - -# Ensure that to_vec works for the types that we care about in this package. -@testset "custom FiniteDifferences stuff" begin - @testset "NamedTuple" begin - a, b = 5.0, randn(2) - t = (a=a, b=b) - nt_vec, back = to_vec(t) - @test nt_vec isa Vector{Float64} - @test back(nt_vec) == t - end - @testset "Fill" begin - @testset "$(typeof(val))" for val in [5.0, randn(3)] - x = Fill(val, 5) - x_vec, back = to_vec(x) - @test x_vec isa Vector{Float64} - @test back(x_vec) == x - end - end - @testset "Zeros{T}" for T in [Float32, Float64] - x = Zeros{T}(4) - x_vec, back = to_vec(x) - @test x_vec isa Vector{eltype(x)} - @test back(x_vec) == x - end - @testset "gaussian" begin - @testset "Gaussian" begin - x = TemporalGPs.Gaussian(randn(3), randn(3, 3)) - x_vec, back = to_vec(x) - @test back(x_vec) == x - end - end - @testset "to_vec(::SmallOutputLGC)" begin - A = randn(2, 2) - a = randn(2) - Q = randn(2, 2) - model = SmallOutputLGC(A, a, Q) - model_vec, model_from_vec = to_vec(model) - @test model_vec isa Vector{<:Real} - @test model_from_vec(model_vec) == model - end - @testset "to_vec(::GaussMarkovModel)" begin - N = 11 - A = [randn(2, 2) for _ in 1:N] - a = [randn(2) for _ in 1:N] - Q = [randn(2, 2) for _ in 1:N] - H = [randn(3, 2) for _ in 1:N] - h = [randn(3) for _ in 1:N] - x0 = TemporalGPs.Gaussian(randn(2), randn(2, 2)) - gmm = TemporalGPs.GaussMarkovModel(Forward(), A, a, Q, x0) - - gmm_vec, gmm_from_vec = to_vec(gmm) - @test gmm_vec isa Vector{<:Real} - @test gmm_from_vec(gmm_vec) == gmm - end - @testset "StructArray" begin - x = StructArray([Gaussian(randn(2), randn(2, 2)) for _ in 1:10]) - x_vec, x_from_vec = to_vec(x) - @test x_vec isa Vector{<:Real} - @test x_from_vec(x_vec) == x - end - @testset "to_vec(::LGSSM)" begin - N = 11 - - # Build GaussMarkovModel. - A = [randn(2, 2) for _ in 1:N] - a = [randn(2) for _ in 1:N] - Q = [randn(2, 2) for _ in 1:N] - x0 = Gaussian(randn(2), randn(2, 2)) - gmm = GaussMarkovModel(Forward(), A, a, Q, x0) - - # Build LGSSM. - H = [randn(3, 2) for _ in 1:N] - h = [randn(3) for _ in 1:N] - Σ = [randn(3, 3) for _ in 1:N] - model = TemporalGPs.LGSSM(gmm, StructArray(map(SmallOutputLGC, H, h, Σ))) - - model_vec, model_from_vec = to_vec(model) - @test model_from_vec(model_vec) == model - end - @testset "to_vec(::BlockDiagonal)" begin - Ns = [3, 5, 1] - Xs = map(N -> randn(N, N), Ns) - X = BlockDiagonal(Xs) - - X_vec, X_from_vec = to_vec(X) - @test X_vec isa Vector{<:Real} - @test X_from_vec(X_vec) == X - end -end - -my_zero(x) = zero(x) -my_zero(x::AbstractArray{<:Real}) = zero(x) -my_zero(x::AbstractArray) = map(my_zero, x) -my_zero(x::Tuple) = map(my_zero, x) - -# My version of isapprox -function fd_isapprox(x_ad::Nothing, x_fd, rtol, atol) - return fd_isapprox(x_fd, my_zero(x_fd), rtol, atol) -end -function fd_isapprox(x_ad::AbstractArray, x_fd::AbstractArray, rtol, atol) - return all(fd_isapprox.(x_ad, x_fd, rtol, atol)) -end -function fd_isapprox(x_ad::Real, x_fd::Real, rtol, atol) - return isapprox(x_ad, x_fd; rtol=rtol, atol=atol) -end -function fd_isapprox(x_ad::NamedTuple, x_fd, rtol, atol) - f = (x_ad, x_fd)->fd_isapprox(x_ad, x_fd, rtol, atol) - return all([f(getfield(x_ad, key), getfield(x_fd, key)) for key in keys(x_ad)]) -end -function fd_isapprox(x_ad::Tuple, x_fd::Tuple, rtol, atol) - return all(map((x, x′)->fd_isapprox(x, x′, rtol, atol), x_ad, x_fd)) -end -function fd_isapprox(x_ad::Dict, x_fd::Dict, rtol, atol) - return all([fd_isapprox(get(()->nothing, x_ad, key), x_fd[key], rtol, atol) for - key in keys(x_fd)]) -end -function fd_isapprox(x::Gaussian, y::Gaussian, rtol, atol) - return isapprox(x.m, y.m; rtol=rtol, atol=atol) && - isapprox(x.P, y.P; rtol=rtol, atol=atol) -end -function fd_isapprox(x::Real, y::ZeroTangent, rtol, atol) - return fd_isapprox(x, zero(x), rtol, atol) -end -fd_isapprox(x::ZeroTangent, y::Real, rtol, atol) = fd_isapprox(y, x, rtol, atol) - -function fd_isapprox(x_ad::T, x_fd::T, rtol, atol) where {T<:NamedTuple} - f = (x_ad, x_fd)->fd_isapprox(x_ad, x_fd, rtol, atol) - return all([f(getfield(x_ad, key), getfield(x_fd, key)) for key in keys(x_ad)]) -end - -function fd_isapprox(x::T, y::T, rtol, atol) where {T} - if !isstructtype(T) - throw(ArgumentError("Non-struct types are not supported by this fallback.")) - end - - return all(n -> fd_isapprox(getfield(x, n), getfield(y, n), rtol, atol), fieldnames(T)) -end - -function adjoint_test( - f, ȳ, x::Tuple, ẋ::Tuple; - rtol=1e-6, - atol=1e-6, - fdm=central_fdm(5, 1; max_range=1e-3), - test=true, - check_inferred=TEST_TYPE_INFER, - context=Context(), - kwargs..., -) - # Compute = using Zygote. - y, pb = Zygote.pullback(f, x...) - - # Check type inference if requested. - if check_inferred - # @descend only works if you `using Cthulhu`. - # @descend Zygote._pullback(context, f, x...) - # @descend pb(ȳ) - - # @code_warntype Zygote._pullback(context, f, x...) - # @code_warntype pb(ȳ) - @inferred Zygote._pullback(context, f, x...) - @inferred pb(ȳ) - end - x̄ = pb(ȳ) - x̄_ad, ẋ_ad = harmonise(Zygote.wrap_chainrules_input(x̄), ẋ) - inner_ad = dot(x̄_ad, ẋ_ad) - - # Approximate = using FiniteDifferences. - # x̄_fd = j′vp(fdm, f, ȳ, x...) - ẏ = jvp(fdm, f, zip(x, ẋ)...) - - ȳ_fd, ẏ_fd = harmonise(Zygote.wrap_chainrules_input(ȳ), ẏ) - inner_fd = dot(ȳ_fd, ẏ_fd) - # Check that Zygote didn't modify the forwards-pass. - test && @test fd_isapprox(y, f(x...), rtol, atol) - - # Check for approximate agreement in "inner-products". - test && @test fd_isapprox(inner_ad, inner_fd, rtol, atol) - - return x̄ -end - -function adjoint_test(f, input::Tuple; kwargs...) - Δoutput = rand_zygote_tangent(f(input...)) - return adjoint_test(f, Δoutput, input; kwargs...) -end - -function adjoint_test(f, Δoutput, input::Tuple; kwargs...) - ∂input = map(rand_zygote_tangent, input) - return adjoint_test(f, Δoutput, input, ∂input; kwargs...) -end - -function print_adjoints(adjoint_ad, adjoint_fd, rtol, atol) - @show typeof(adjoint_ad), typeof(adjoint_fd) - - # println("ad") - # display(adjoint_ad) - # println() - - # println("fd") - # display(adjoint_fd) - # println() - - adjoint_ad, adjoint_fd = to_vec(adjoint_ad)[1], to_vec(adjoint_fd)[1] - println("atol is $atol, rtol is $rtol") - println("ad, fd, abs, rel") - abs_err = abs.(adjoint_ad .- adjoint_fd) - rel_err = abs_err ./ adjoint_ad - display([adjoint_ad adjoint_fd abs_err rel_err]) - println() -end - -using BenchmarkTools - -# Also checks the forwards-pass because it's helpful. -function check_adjoint_allocations( - f, Δoutput, input::Tuple; - context=NoContext(), - max_primal_allocs=0, - max_forward_allocs=0, - max_backward_allocs=0, - kwargs..., -) - _, pb = _pullback(context, f, input...) - - primal_allocs = allocs(@benchmark($f($input...); samples=1, evals=1)) - forward_allocs = allocs( - @benchmark(_pullback($context, $f, $input...); samples=1, evals=1), - ) - backward_allocs = allocs(@benchmark $pb($Δoutput) samples=1 evals=1) - - # primal_allocs = allocs(@benchmark($f($input...))) - # forward_allocs = allocs( - # @benchmark(_pullback($context, $f, $input...)), - # ) - # backward_allocs = allocs(@benchmark $pb($Δoutput)) - - # @show primal_allocs - # @show forward_allocs - # @show backward_allocs - - @test primal_allocs <= max_primal_allocs - @test forward_allocs <= max_forward_allocs - @test backward_allocs <= max_backward_allocs -end - -function check_adjoint_allocations(f, input::Tuple; kwargs...) - return check_adjoint_allocations(f, rand_zygote_tangent(f(input...)), input; kwargs...) -end - -function benchmark_adjoint(f, ȳ, args...; disp=false) - disp && println("primal") - primal = @benchmark($f($args...); samples=1, evals=1) - if disp - display(primal) - println() - end - - disp && println("pullback generation") - forward_pass = @benchmark(Zygote.pullback($f, $args...); samples=1, evals=1) - if disp - display(forward_pass) - println() - end - - y, back = Zygote.pullback(f, args...) - - disp && println("pullback evaluation") - reverse_pass = @benchmark($back($ȳ); samples=1, evals=1) - if disp - display(reverse_pass) - println() - end - - return primal, forward_pass, reverse_pass -end function test_interface( rng::AbstractRNG, conditional::AbstractLGC, x::Gaussian; - check_inferred=TEST_TYPE_INFER, check_adjoints=true, check_allocs=TEST_ALLOC, atol=1e-6, rtol=1e-6, kwargs..., + check_inferred=TEST_TYPE_INFER, check_adjoints=true, check_allocs=TEST_ALLOC, kwargs..., ) x_val = rand(rng, x) y = conditional_rand(rng, conditional, x_val) + perf_flag = check_allocs ? :allocs : :none @testset "rand" begin @test length(y) == dim_out(conditional) args = (TemporalGPs.ε_randn(rng, conditional), conditional, x_val) check_inferred && @inferred conditional_rand(args...) - if check_adjoints - test_zygote_grad( - conditional_rand, args...; - check_inferred, rtol, atol, - ) - end - if check_allocs - check_adjoint_allocations(conditional_rand, args; kwargs...) - end + check_adjoints && test_rule(rng, conditional_rand, args...; perf_flag, is_primitive=false) end @testset "predict" begin @test predict(x, conditional) isa Gaussian check_inferred && @inferred predict(x, conditional) - check_adjoints && adjoint_test(predict, (x, conditional); kwargs...) - check_allocs && check_adjoint_allocations(predict, (x, conditional); kwargs...) + check_adjoints && test_rule(rng, predict, x, conditional; perf_flag, is_primitive=false) end conditional isa ScalarOutputLGC || @testset "predict_marginals" begin @@ -525,21 +53,7 @@ function test_interface( args = (x, conditional, y) @test posterior_and_lml(args...) isa Tuple{Gaussian, Real} check_inferred && @inferred posterior_and_lml(args...) - if check_adjoints - (Δx, Δlml) = rand_zygote_tangent(posterior_and_lml(args...)) - ∂args = map(rand_tangent, args) - adjoint_test(posterior_and_lml, (Δx, Δlml), args, ∂args) - adjoint_test(posterior_and_lml, (Δx, nothing), args, ∂args) - adjoint_test(posterior_and_lml, (nothing, Δlml), args, ∂args) - adjoint_test(posterior_and_lml, (nothing, nothing), args, ∂args) - end - if check_allocs - (Δx, Δlml) = rand_zygote_tangent(posterior_and_lml(args...)) - check_adjoint_allocations(posterior_and_lml, (Δx, Δlml), args; kwargs...) - check_adjoint_allocations(posterior_and_lml, (nothing, Δlml), args; kwargs...) - check_adjoint_allocations(posterior_and_lml, (Δx, nothing), args; kwargs...) - check_adjoint_allocations(posterior_and_lml, (nothing, nothing), args; kwargs...) - end + check_adjoints && test_rule(rng, posterior_and_lml, args...; perf_flag, is_primitive=false) end end @@ -557,6 +71,7 @@ function test_interface( rng::AbstractRNG, ssm::AbstractLGSSM; check_inferred=TEST_TYPE_INFER, check_adjoints=true, check_allocs=TEST_ALLOC, rtol, atol, kwargs... ) + perf_flag = check_allocs ? :allocs : :none y_no_missing = rand(rng, ssm) @testset "LGSSM interface" begin @testset "rand" begin @@ -565,19 +80,7 @@ function test_interface( @test length(y_no_missing) == length(ssm) check_inferred && @inferred rand(rng, ssm) rng = MersenneTwister(123456) - if check_adjoints - # We need the whole scan_emit machinery to test the adjoint of rand - @test_broken 1 == 0 - # It seems test_rrule cannot deal good with `rng` at the moment - # test_zygote_grad(rng, ssm; check_inferred, rtol, atol) do rng, model - # iterable = zip(ε_randn(rng, model), model) - # init = rand(rng, x0(model)) - # return scan_emit(step_rand, iterable, init, eachindex(model)) - # end - end - if check_allocs - check_adjoint_allocations(rand, (rng, ssm); kwargs...) - end + check_adjoints && test_rule(rng, rand, rng, ssm; perf_flag, interface_only=true, is_primitive=false) end @testset "basics" begin @@ -592,13 +95,10 @@ function test_interface( @test length(xs) == length(ssm) check_inferred && @inferred marginals(ssm) if check_adjoints - # We need to test the whole scan_emit to avoid throwing a state. - test_zygote_grad(ssm; check_inferred, rtol, atol) do model - scan_emit(step_marginals, model, x0(model), eachindex(model)) - end - end - if check_allocs - check_adjoint_allocations(marginals, (ssm, ); kwargs...) + test_rule( + rng, scan_emit, step_marginals, ssm, x0(ssm), eachindex(ssm); + perf_flag, is_primitive=false, interface_only=true, + ) end end @@ -615,9 +115,10 @@ function test_interface( @test is_of_storage_type(lml, storage_type(ssm)) _check_inferred && @inferred logpdf(ssm, y) if check_adjoints - test_zygote_grad(ssm, y; check_inferred, rtol, atol) do model, y - scan_emit(step_logpdf, zip(model, y), x0(model), eachindex(model)) - end + test_rule( + rng, scan_emit, step_logpdf, zip(ssm, y), x0(ssm), eachindex(ssm); + perf_flag, is_primitive=false, interface_only=true, + ) end end @testset "_filter" begin @@ -627,9 +128,10 @@ function test_interface( @test length(xs) == length(ssm) _check_inferred && @inferred _filter(ssm, y) if check_adjoints - test_zygote_grad(ssm, y; check_inferred, rtol, atol) do model, y - scan_emit(step_filter, zip(model, y), x0(model), eachindex(model)) - end + test_rule( + rng, scan_emit, step_filter, zip(ssm, y), x0(ssm), eachindex(ssm); + perf_flag, is_primitive=false, interface_only=true, + ) end end @testset "posterior" begin @@ -638,53 +140,12 @@ function test_interface( @test ordering(posterior_ssm) != ordering(ssm) _check_inferred && @inferred posterior(ssm, y) if check_adjoints - test_zygote_grad(posterior, ssm, y; check_inferred, rtol, atol) - end - end - - # Hack to only run the AD tests if requested. - @testset "adjoints" for _ in (check_adjoints ? [1] : []) - if check_allocs - check_adjoint_allocations(_filter, (ssm, y); kwargs...) - check_adjoint_allocations(posterior, (ssm, y); kwargs...) + test_rule( + rng, posterior, ssm, y; + perf_flag, is_primitive=false, interface_only=true, + ) end end end end end - -# This is unfortunately needed to make ChainRulesTestUtils comparison works. -# See https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/271 -Base.zero(::Forward) = Forward() -Base.zero(::Reverse) = Reverse() - -_diag(x) = diag(x) -_diag(x::Real) = x - -function FiniteDifferences.rand_tangent(rng::AbstractRNG, A::StaticArray) - return map(x -> rand_tangent(rng, x), A) -end - -FiniteDifferences.rand_tangent(::AbstractRNG, ::Base.OneTo) = ZeroTangent() - -# Hacks to make rand_tangent play nicely with Zygote. -rand_zygote_tangent(A) = Zygote.wrap_chainrules_output(FiniteDifferences.rand_tangent(A)) - -Zygote.wrap_chainrules_output(x::Array) = map(Zygote.wrap_chainrules_output, x) - -function Zygote.wrap_chainrules_input(x::Array) - return map(Zygote.wrap_chainrules_input, x) -end - -function LinearAlgebra.dot(A::Tangent, B::Tangent) - mutual_names = intersect(propertynames(A), propertynames(B)) - if length(mutual_names) == 0 - return 0 - else - return sum(n -> dot(getproperty(A, n), getproperty(B, n)), mutual_names) - end -end - -function ChainRulesTestUtils.test_approx(actual::Tangent{T}, expected::StructArray, msg=""; kwargs...) where {T<:StructArray} - return test_approx(actual.components, expected; kwargs...) -end \ No newline at end of file diff --git a/test/util/chainrules.jl b/test/util/chainrules.jl deleted file mode 100644 index b68ab30c..00000000 --- a/test/util/chainrules.jl +++ /dev/null @@ -1,117 +0,0 @@ -using StaticArrays -using BenchmarkTools -using BlockDiagonals -using ChainRulesCore -using ChainRulesTestUtils -using Test -using TemporalGPs -using TemporalGPs: time_exp, _map, Gaussian -using FillArrays -using StructArrays -using Zygote: ZygoteRuleConfig - -@testset "chainrules" begin - @testset "StaticArrays" begin - @testset "SArray constructor" begin - for (f, x) in ( - (SArray{Tuple{3, 2, 1}}, ntuple(i -> 2.5i, 6)), - (SVector{5}, (ntuple(i -> 2.5i, 5))), - (SVector{2}, (2.0, 1.0)), - (SMatrix{5, 4}, (ntuple(i -> 2.5i, 20))), - (SMatrix{1, 1}, (randn(),)) - ) - test_rrule(ZygoteRuleConfig(), f, x; rrule_f=rrule_via_ad) - end - end - @testset "collect(::SArray)" begin - A = SArray{Tuple{3, 1, 2}}(ntuple(i -> 3.5i, 6)) - test_rrule(collect, A) - end - @testset "vcat(::SVector, ::SVector)" begin - a = SVector{3}(randn(3)) - b = SVector{2}(randn(2)) - test_rrule(vcat, a, b) - end - end - @testset "time_exp" begin - A = randn(3, 3) - test_rrule(time_exp, A ⊢ NoTangent(), 0.1) - end - @testset "Fill" begin - @testset "Fill constructor" begin - for x in ( - randn(), - randn(1, 2), - SMatrix{1, 2}(randn(1, 2)), - ) - test_rrule(Fill, x, 3; check_inferred=false) - test_rrule(Fill, x, (3, 4); check_inferred=false) - end - end - @testset "collect(::Fill)" begin - P = 11 - Q = 3 - @testset "$(typeof(x)) element" for x in [ - randn(), - randn(1, 2), - SMatrix{1, 2}(randn(1, 2)), - ] - test_rrule(collect, Fill(x, P)) - # The test rule does not work due to inconsistencies of FiniteDifferencies for FillArrays - test_rrule(collect, Fill(x, P, Q)) - end - end - end - - # The rrule is not even used... - @testset "getindex(::Fill, ::Int)" begin - X = Fill(randn(5, 3), 10) - test_rrule(getindex, X, 3; check_inferred=false) - end - @testset "BlockDiagonal" begin - X = map(N -> randn(N, N), [3, 4, 1]) - test_rrule(BlockDiagonal, X) - end - @testset "_map(f, x::Fill)" begin - x = Fill(randn(3, 4), 4) - test_rrule(_map, sum, x; check_inferred=false) - test_rrule(_map, x->map(sin, x), x; check_inferred=false) - test_rrule(_map, x -> 2.0 * x, x; check_inferred=false) - test_rrule(ZygoteRuleConfig(), (x,a)-> _map(x -> x * a, x), x, 2.0; check_inferred=false, rrule_f=rrule_via_ad) - end - @testset "_map(f, x::Fill....)" begin - x1 = Fill(randn(3, 4), 3) - x2 = Fill(randn(3, 4), 3) - x3 = Fill(randn(3, 4), 3) - - @test _map(+, x1, x2) == _map(+, collect(x1), collect(x2)) - test_rrule(_map, +, x1, x2; check_inferred=true) - - @test _map(+, x1, x2, x3) == _map(+, collect(x1), collect(x2), collect(x3)) - test_rrule(_map, +, x1, x2, x3; check_inferred=true) - - fsin(x, y) = sin.(x .* y) - test_rrule(_map, fsin, x1, x2; check_inferred=false) - - foo(a, x1, x2) = _map((z1, z2) -> a * sin.(z1 .* z2), x1, x2) - test_rrule(ZygoteRuleConfig(), foo, randn(), x1, x2; check_inferred=false, rrule_f=rrule_via_ad) - end - @testset "StructArray" begin - a = randn(5) - b = rand(5) - # This test is broken due to FiniteDifferences returning the wrong Tangent. - @test_broken 1 == 0 - # test_rrule(StructArray, (a, b); check_inferred=false) - - xs = [Gaussian(randn(1), randn(1, 1)) for _ in 1:2] - ms = getfield.(xs, :m) - Ps = getfield.(xs, :P) - # Same here. - @test_broken 1 == 0 - # test_rrule(StructArray{eltype(xs)}, (ms, Ps)) - xs_sa = StructArray{eltype(xs)}((ms, Ps)) - # And here. - @test_broken 1 == 0 - # test_zygote_grad(getproperty, xs_sa, :m) - end -end diff --git a/test/util/gaussian.jl b/test/util/gaussian.jl index 18e7faa9..e5bff28c 100644 --- a/test/util/gaussian.jl +++ b/test/util/gaussian.jl @@ -1,13 +1,3 @@ -using TemporalGPs: Gaussian - -# This is a ridiculous definition that makes no sense. Don't use this anywhere. -Base.zero(x::Gaussian) = Gaussian(zero(x.m), zero(x.P)) - -function fd_isapprox(x_ad::Gaussian, x_fd::Gaussian, rtol, atol) - return fd_isapprox(x_ad.m, x_fd.m, rtol, atol) && - fd_isapprox(x_ad.P, x_fd.P, rtol, atol) -end - @testset "Gaussian" begin N = 11 @test TemporalGPs.dim(Gaussian(randn(N), randn(N, N))) == N diff --git a/test/util/harmonise.jl b/test/util/harmonise.jl deleted file mode 100644 index 50ae7733..00000000 --- a/test/util/harmonise.jl +++ /dev/null @@ -1,57 +0,0 @@ -using TemporalGPs: are_harmonised - -function test_harmonise(a, b; recurse=true) - h = harmonise(a, b) - @test h isa Tuple - @test length(h) == 2 - @test are_harmonised(h[1], h[2]) - - recurse && test_harmonise(b, a; recurse=false) - h′ = harmonise(b, a) - @test h isa Tuple - @test length(h) == 2 - @test are_harmonised(h′[1], h′[2]) - @test are_harmonised(h[1], h′[1]) - @test are_harmonised(h[1], h′[2]) -end - -@testset "harmonise" begin - test_harmonise(5.0, 4.0) - - @testset "AbstractZero" begin - test_harmonise(5.0, ZeroTangent()) - test_harmonise(ZeroTangent(), randn(10)) - test_harmonise(ZeroTangent(), ZeroTangent()) - end - - @testset "Array" begin - test_harmonise(randn(5), randn(5)) - test_harmonise( - [(randn(), randn()) for _ in 1:10], - [Tangent{Any}(randn(), rand()) for _ in 1:10], - ) - end - - @testset "Tuple / Tangent{Tuple}" begin - test_harmonise((5, 4), (5, 4)) - test_harmonise(Tangent{Tuple}(5, 4), (5, 4)) - test_harmonise(Tangent{Tuple}(5, 4), Tangent{Tuple}(5, 4)) - - test_harmonise((5, Tangent{Tuple}(randn(5))), (5, (randn(5), ))) - test_harmonise( - Tangent{Any}(Tangent{Any}(randn(5))), - (Tangent{Any}(randn(5)), ), - ) - end - - @testset "NamedTuple / Tangent{NamedTuple}" begin - test_harmonise(Tangent{Any}(; m=4, P=5), Tangent{Gaussian}(; m=5, P=4)) - test_harmonise(Tangent{Any}(; m=4, P=5), Tangent{Any}(; m=4)) - test_harmonise(Tangent{Any}(; m=5), Tangent{Any}(; P=4)) - - test_harmonise(Tangent{Any}(; m=(5, 4)), Tangent{Any}(; P=4)) - - test_harmonise(Tangent{Any}(; m=5, P=4), Gaussian(5, 4)) - test_harmonise(Tangent{Any}(; P=4), Gaussian(4, 5)) - end -end diff --git a/test/util/mul.jl b/test/util/mul.jl index d91df46f..ef68c8b6 100644 --- a/test/util/mul.jl +++ b/test/util/mul.jl @@ -1,6 +1,3 @@ -using Random: MersenneTwister -using LinearAlgebra: mul! - @testset "mul" begin rng = MersenneTwister(123456) P = 50 diff --git a/test/util/regular_data.jl b/test/util/regular_data.jl index 9cc2c042..f56bb0e6 100644 --- a/test/util/regular_data.jl +++ b/test/util/regular_data.jl @@ -1,13 +1,3 @@ -using FiniteDifferences -using Zygote - -function FiniteDifferences.to_vec(x::RegularSpacing) - function from_vec_RegularSpacing(x_vec) - return RegularSpacing(x_vec[1], x_vec[2], x.N) - end - return [x.t0, x.Δt], from_vec_RegularSpacing -end - @testset "regular_data" begin t0 = randn() Δt = randn() @@ -20,14 +10,4 @@ end @test collect(x) ≈ collect(x_range) @test step(x) == step(x_range) @test length(x) == length(x_range) - - let - x, back = Zygote.pullback(RegularSpacing, t0, Δt, N) - - Δ_t0 = randn() - Δ_Δt = randn() - @test back((t0 = Δ_t0, Δt = Δ_Δt, N=nothing)) == (Δ_t0, Δ_Δt, nothing) - - test_rrule(RegularSpacing, randn(), rand(), 10; output_tangent=Tangent{RegularSpacing}(Δt=0.1, t0=0.2)) - end end diff --git a/test/util/scan.jl b/test/util/scan.jl index 7f1d47b9..9340498a 100644 --- a/test/util/scan.jl +++ b/test/util/scan.jl @@ -1,16 +1,3 @@ -using Test -using Zygote: ZygoteRuleConfig -using TemporalGPs: scan_emit -using StructArrays -using ChainRulesTestUtils - @testset "scan" begin - # Run forwards. - x = StructArray([(a=randn(), b=randn()) for _ in 1:10]) - stepper = (x_, y_) -> (x_ + y_.a * y_.b * x_, x_ + y_.b) - # test_rrule(scan_emit, stepper, x, 0.0, eachindex(x)) - - # Run in reverse. - # test_rrule(scan_emit, stepper, x, 0.0, reverse(eachindex(x))) end diff --git a/test/util/zygote_friendly_map.jl b/test/util/zygote_friendly_map.jl index e81c21b4..eb817658 100644 --- a/test/util/zygote_friendly_map.jl +++ b/test/util/zygote_friendly_map.jl @@ -1,6 +1,3 @@ -using FillArrays -using TemporalGPs - @testset "zygote_friendly_map" begin @testset "$name" for (name, f, x) in [ ("Vector{Float64}", x -> sin(x) + cos(x) * exp(x), randn(100)), @@ -13,6 +10,5 @@ using TemporalGPs ), ] @test TemporalGPs.zygote_friendly_map(f, x) ≈ map(f, x) - # adjoint_test(x -> TemporalGPs.zygote_friendly_map(f, x), (x, )) end end From e1170990888e8f5bf6a51a0577551e8b7551adf5 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Thu, 26 Sep 2024 23:25:21 +0100 Subject: [PATCH 05/21] Add Test as test dep --- Project.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 1334e9a9..63eb964a 100644 --- a/Project.toml +++ b/Project.toml @@ -36,6 +36,7 @@ julia = "1.6" [extras] BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["BenchmarkTools", "Mooncake"] +test = ["BenchmarkTools", "Mooncake", "Test"] From d73e1e281703ce5e4efa856764b65a42d765fbaa Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Thu, 26 Sep 2024 23:27:09 +0100 Subject: [PATCH 06/21] Fix typo --- test/runtests.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 7e3bfb53..98eb9406 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -14,7 +14,7 @@ include("front_matter.jl") end end - if GROUP == "test models" begin + if GROUP == "test models" @testset "models" begin println("models:") include(joinpath("models", "test_model_test_utils.jl")) @@ -25,7 +25,7 @@ include("front_matter.jl") end end - if GROUP == "test gp" begin + if GROUP == "test gp" println("gp:") @testset "gp" begin include(joinpath("gp", "lti_sde.jl")) @@ -33,7 +33,7 @@ include("front_matter.jl") end end - if GROUP == "test space_time" begin + if GROUP == "test space_time" println("space_time:") @testset "space_time" begin include(joinpath("space_time", "rectilinear_grid.jl")) From e983dbdfc793096ab1084142f6f2525b4358cd06 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Thu, 26 Sep 2024 23:29:57 +0100 Subject: [PATCH 07/21] Add Pkg to examples --- examples/Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/Project.toml b/examples/Project.toml index 79bd1bd5..83ce622f 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -5,6 +5,7 @@ KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Optim = "429524aa-4258-5aef-a3af-852621145aeb" ParameterHandling = "2412ca09-6db7-441c-8e3a-88d5709968c5" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" TemporalGPs = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f" From 5c0e38a900748dacd643f5f0e1b969de4dd7f9c5 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Thu, 26 Sep 2024 23:31:16 +0100 Subject: [PATCH 08/21] Add Pkg to test deps --- Project.toml | 3 ++- examples/Project.toml | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 63eb964a..0197aee4 100644 --- a/Project.toml +++ b/Project.toml @@ -36,7 +36,8 @@ julia = "1.6" [extras] BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["BenchmarkTools", "Mooncake", "Test"] +test = ["BenchmarkTools", "Mooncake", "Pkg", "Test"] diff --git a/examples/Project.toml b/examples/Project.toml index 83ce622f..79bd1bd5 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -5,7 +5,6 @@ KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Optim = "429524aa-4258-5aef-a3af-852621145aeb" ParameterHandling = "2412ca09-6db7-441c-8e3a-88d5709968c5" -Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" TemporalGPs = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f" From cfc5b107d621dbaaa8ab102bba73cf2b338a32f9 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Thu, 26 Sep 2024 23:33:00 +0100 Subject: [PATCH 09/21] Require Mooncake 0-4-3 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 0197aee4..31e2efd6 100644 --- a/Project.toml +++ b/Project.toml @@ -28,7 +28,7 @@ Bessels = "0.2.8" BlockDiagonals = "0.1.7" FillArrays = "0.13.0 - 0.13.7, 1" KernelFunctions = "0.9, 0.10.1" -Mooncake = "0.4" +Mooncake = "0.4.3" StaticArrays = "1" StructArrays = "0.5, 0.6" julia = "1.6" From 3f59b1c7da9cde17715e877288298864682972a3 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Thu, 26 Sep 2024 23:40:20 +0100 Subject: [PATCH 10/21] Import more names --- test/front_matter.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/front_matter.jl b/test/front_matter.jl index 313baa40..ae385e1c 100644 --- a/test/front_matter.jl +++ b/test/front_matter.jl @@ -21,7 +21,9 @@ using TemporalGPs: fill_in_missings, replace_observation_noise_cov, scan_emit, - transform_model_and_obs + transform_model_and_obs, + RectilinearGrid, + RegularInTime ENV["TESTING"] = "TRUE" From a3e3eb5587b044ba4b8e786e9640b2ec5312719d Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Thu, 26 Sep 2024 23:52:18 +0100 Subject: [PATCH 11/21] Remove Mooncake as direct dep --- Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/Project.toml b/Project.toml index 31e2efd6..e54428b9 100644 --- a/Project.toml +++ b/Project.toml @@ -10,7 +10,6 @@ BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" From 69727d4ffb7a62b5dd58fc4f65d4711f31a8e417 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Thu, 26 Sep 2024 23:52:27 +0100 Subject: [PATCH 12/21] Formatting --- src/gp/data_representations.jl | 3 +-- src/gp/lti_sde.jl | 14 +++++++------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/gp/data_representations.jl b/src/gp/data_representations.jl index 82d46c66..419a1e90 100644 --- a/src/gp/data_representations.jl +++ b/src/gp/data_representations.jl @@ -51,8 +51,7 @@ and spatio-temporal problems typically have multiple elements of `y` associated single element of `x`. """ function observations_to_time_form( - x::AbstractVector{<:Real}, - y::AbstractVector{<:Union{Missing, Real}}, + x::AbstractVector{<:Real}, y::AbstractVector{<:Union{Missing, Real}} ) return y end diff --git a/src/gp/lti_sde.jl b/src/gp/lti_sde.jl index 2b521b18..dccabac8 100644 --- a/src/gp/lti_sde.jl +++ b/src/gp/lti_sde.jl @@ -121,7 +121,6 @@ function lgssm_components( m = collect(mean_vector(m, t)) # `collect` is needed as there are still issues with Zygote and FillArrays. As, as, Qs, (Hs, hs), x0 = lgssm_components(k, t, storage_type) hs = add_proj_mean(hs, m) - return As, as, Qs, (Hs, hs), x0 end @@ -136,7 +135,9 @@ time_exp(A, t) = exp(A * t) # Generic constructors for base kernels. -function broadcast_components((F, q, H)::Tuple, x0::Gaussian, t::AbstractVector{<:Real}, ::StorageType{T}) where {T} +function broadcast_components( + (F, q, H)::Tuple, x0::Gaussian, t::AbstractVector{<:Real}, ::StorageType{T} +) where {T} P = Symmetric(x0.P) t = vcat([first(t) - 1], t) As = map(Δt -> time_exp(F, T(Δt)), diff(t)) @@ -147,7 +148,9 @@ function broadcast_components((F, q, H)::Tuple, x0::Gaussian, t::AbstractVector{ As, as, Qs, Hs, hs end -function broadcast_components((F, q, H)::Tuple, x0::Gaussian, t::Union{StepRangeLen, RegularSpacing}, ::StorageType{T}) where {T} +function broadcast_components( + (F, q, H)::Tuple, x0::Gaussian, t::Union{StepRangeLen, RegularSpacing}, ::StorageType{T} +) where {T} P = Symmetric(x0.P) A = time_exp(F, T(step(t))) As = Fill(A, length(t)) @@ -316,10 +319,7 @@ function TemporalGPs.to_sde(::ConstantKernel, ::SArrayStorage{T}) where {T<:Real end function TemporalGPs.stationary_distribution(k::ConstantKernel, ::SArrayStorage{T}) where {T<:Real} - return TemporalGPs.Gaussian( - SVector{1, T}(0), - SMatrix{1, 1, T}( T(only(k.c)) ), - ) + return TemporalGPs.Gaussian(SVector{1, T}(0), SMatrix{1, 1, T}(T(only(k.c)))) end # Scaled From 219fe2210a8f42d2b56b62015f8dedba9dd42e0f Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 27 Sep 2024 09:10:10 +0100 Subject: [PATCH 13/21] Formatting --- test/models/lgssm.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/models/lgssm.jl b/test/models/lgssm.jl index 891101d8..7bc562f5 100644 --- a/test/models/lgssm.jl +++ b/test/models/lgssm.jl @@ -85,7 +85,6 @@ println("lgssm:") @testset "step_marginals" begin @inferred step_marginals(x, model[1]) test_rule(rng, step_marginals, x, model[1]; is_primitive=false) - end @testset "step_logpdf" begin args = (ordering(model[1]), x, (model[1], y)) From 495e1371014501ba3c47c18ddfa0239c30d1be6a Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 27 Sep 2024 09:50:00 +0100 Subject: [PATCH 14/21] Tidy up + enable all tests --- ext/TemporalGPsMooncakeExt.jl | 15 ++++++++--- test/gp/lti_sde.jl | 47 ++++++++++++++++------------------- 2 files changed, 33 insertions(+), 29 deletions(-) diff --git a/ext/TemporalGPsMooncakeExt.jl b/ext/TemporalGPsMooncakeExt.jl index 4be3a2e9..e87c56e3 100644 --- a/ext/TemporalGPsMooncakeExt.jl +++ b/ext/TemporalGPsMooncakeExt.jl @@ -1,14 +1,21 @@ module TemporalGPsMooncakeExt using Mooncake, TemporalGPs -import Mooncake: rrule!!, CoDual, primal, @is_primitive, zero_fcodual, MinimalCtx +import Mooncake: + rrule!!, + CoDual, + primal, + @is_primitive, + zero_fcodual, + MinimalCtx -@is_primitive MinimalCtx Tuple{typeof(TemporalGPs.time_exp), AbstractMatrix{<:Real}, Real} +@is_primitive MinimalCtx Tuple{typeof(TemporalGPs.time_exp), Matrix{<:Real}, Real} function rrule!!(::CoDual{typeof(TemporalGPs.time_exp)}, A::CoDual, t::CoDual{Float64}) - B_dB = zero_fcodual(TemporalGPs.time_exp(primal(A), primal(t))) + _A = primal(A) + B_dB = zero_fcodual(TemporalGPs.time_exp(_A, primal(t))) B = primal(B_dB) dB = tangent(B_dB) - time_exp_pb(::NoRData) = NoRData(), NoRData(), sum(dB .* (primal(A) * B)) + time_exp_pb(::NoRData) = NoRData(), NoRData(), sum(dB .* (_A * B)) return B_dB, time_exp_pb end diff --git a/test/gp/lti_sde.jl b/test/gp/lti_sde.jl index 4d07965e..7b0066e3 100644 --- a/test/gp/lti_sde.jl +++ b/test/gp/lti_sde.jl @@ -3,8 +3,6 @@ using KernelFunctions: kappa using TemporalGPs: build_lgssm, StorageType, is_of_storage_type, lgssm_components using Test -_logistic(x) = 1 / (1 + exp(-x)) - # Everything is tested once the LGSSM is constructed, so it is sufficient just to ensure # that Zygote can handle construction. function _construction_tester(f_naive::GP, storage::StorageType, σ², t::AbstractVector) @@ -92,51 +90,40 @@ println("lti_sde:") N = 13 kernels = vcat( # Base kernels. - (name="base-Matern12Kernel", val=Matern12Kernel(), to_vec_grad=false), + (name="base-Matern12Kernel", val=Matern12Kernel()), map([Matern32Kernel, Matern52Kernel]) do k - (; name="base-$k", val=k(), to_vec_grad=false) + (; name="base-$k", val=k()) end, # Scaled kernels. map([1e-1, 1.0, 10.0, 100.0]) do σ² - (; name="scaled-σ²=$σ²", val=σ² * Matern32Kernel(), to_vec_grad=false) + (; name="scaled-σ²=$σ²", val=σ² * Matern32Kernel()) end, # Stretched kernels. map([1e-2, 0.1, 1.0, 10.0, 100.0]) do λ - (; name="stretched-λ=$λ", val=Matern32Kernel() ∘ ScaleTransform(λ), to_vec_grad=false) + (; name="stretched-λ=$λ", val=Matern32Kernel() ∘ ScaleTransform(λ)) end, # Approx periodic kernels map([7, 11]) do N - ( - name="approx-periodic-N=$N", - val=ApproxPeriodicKernel{N}(; r=1.0), - to_vec_grad=true, - ) + (name="approx-periodic-N=$N", val=ApproxPeriodicKernel{N}(; r=1.0)) end, - # TEST_TOFIX - # Gradients should be fixed on those composites. - # Error is mostly due do an incompatibility of Tangents - # between Zygote and FiniteDifferences. # Product kernels ( name="prod-Matern12Kernel-Matern32Kernel", val=1.5 * Matern12Kernel() ∘ ScaleTransform(0.1) * Matern32Kernel() ∘ ScaleTransform(1.1), - to_vec_grad=nothing, ), ( name="prod-Matern32Kernel-Matern52Kernel-ConstantKernel", val=3.0 * Matern32Kernel() * Matern52Kernel() * ConstantKernel(), - to_vec_grad=nothing, ), # THIS IS KNOWN NOT TO WORK! # ( # name="prod-(Matern32Kernel + ConstantKernel) * Matern52Kernel", # val=(Matern32Kernel() + ConstantKernel()) * Matern52Kernel(), - # to_vec_grad=nothing, # ), # Summed kernels. @@ -144,32 +131,30 @@ println("lti_sde:") name="sum-Matern12Kernel-Matern32Kernel", val=1.5 * Matern12Kernel() ∘ ScaleTransform(0.1) + 0.3 * Matern32Kernel() ∘ ScaleTransform(1.1), - to_vec_grad=nothing, ), ( name="sum-Matern32Kernel-Matern52Kernel-ConstantKernel", val=2.0 * Matern32Kernel() + 0.5 * Matern52Kernel() + 1.0 * ConstantKernel(), - to_vec_grad=nothing, ), ) # Construct a Gauss-Markov model with either dense storage or static storage. storages = ( (name="dense storage Float64", val=ArrayStorage(Float64)), - # (name="static storage Float64", val=SArrayStorage(Float64)), + (name="static storage Float64", val=SArrayStorage(Float64)), ) # Either regular spacing or irregular spacing in time. ts = ( (name="irregular spacing", val=collect(RegularSpacing(0.0, 0.3, N))), - # (name="regular spacing", val=RegularSpacing(0.0, 0.3, N)), + (name="regular spacing", val=RegularSpacing(0.0, 0.3, N)), ) σ²s = ( (name="homoscedastic noise", val=(0.1,)), - # (name="heteroscedastic noise", val=(rand(rng, N) .+ 1e-1, )), + (name="heteroscedastic noise", val=(rand(rng, N) .+ 1e-1, )), ) means = ( @@ -178,8 +163,8 @@ println("lti_sde:") (name="Custom Mean", val=CustomMean(x -> 2x)), ) - @testset "$(kernel.name), $(m.name), $(storage.name), $(t.name), $(σ².name)" for kernel in - kernels, + @testset "$(kernel.name), $(m.name), $(storage.name), $(t.name), $(σ².name)" for + kernel in kernels, m in means, storage in storages, t in ts, @@ -187,6 +172,12 @@ println("lti_sde:") println("$(kernel.name), $(storage.name), $(m.name), $(t.name), $(σ².name)") + if kernel.val isa TemporalGPs.ApproxPeriodicKernel && + storage.val isa SArrayStorage + @info "skipping because ApproxPeriodicKernel not compatible with SArrayStorage" + continue + end + # Construct Gauss-Markov model. f_naive = GP(m.val, kernel.val) fx_naive = f_naive(collect(t.val), σ².val...) @@ -217,4 +208,10 @@ println("lti_sde:") ) end end + @testset "time_exp AD" begin + test_rule( + Xoshiro(123), t -> TemporalGPs.time_exp([1.0 2.0; 3.0 4.0], t), rand(); + is_primitive=false, + ) + end end From 7e67f65fd601ffea47c67a6f2bb04e8edc85111e Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 27 Sep 2024 11:42:38 +0100 Subject: [PATCH 15/21] Enable all tests --- test/models/lgssm.jl | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/test/models/lgssm.jl b/test/models/lgssm.jl index 7bc562f5..20f026e8 100644 --- a/test/models/lgssm.jl +++ b/test/models/lgssm.jl @@ -35,10 +35,10 @@ println("lgssm:") settings = [ (tv=:time_varying, N=1, Dlat=3, Dobs=2, storage=storages.dense), (tv=:time_varying, N=49, Dlat=3, Dobs=2, storage=storages.dense), - # (tv=:time_invariant, N=49, Dlat=3, Dobs=2, storage=storages.dense), + (tv=:time_invariant, N=49, Dlat=3, Dobs=2, storage=storages.dense), (tv=:time_varying, N=49, Dlat=1, Dobs=1, storage=storages.dense), (tv=:time_varying, N=1, Dlat=3, Dobs=2, storage=storages.static), - # (tv=:time_invariant, N=49, Dlat=3, Dobs=2, storage=storages.static), + (tv=:time_invariant, N=49, Dlat=3, Dobs=2, storage=storages.static), ] orderings = [ Forward(), @@ -46,7 +46,7 @@ println("lgssm:") ] Qs = [ Val(:dense), - # Val(:diag), diag tests don't work because `FiniteDiffernces.to_vec`. + Val(:diag), ] @testset "($tv, $N, $Dlat, $Dobs, $(storage.name), $(emission.name), $order, $Q)" for @@ -82,29 +82,30 @@ println("lgssm:") y = first(rand(model)) x = TemporalGPs.x0(model) + interface_only = true @testset "step_marginals" begin @inferred step_marginals(x, model[1]) - test_rule(rng, step_marginals, x, model[1]; is_primitive=false) + test_rule(rng, step_marginals, x, model[1]; is_primitive=false, interface_only) end @testset "step_logpdf" begin args = (ordering(model[1]), x, (model[1], y)) @inferred step_logpdf(args...) - test_rule(rng, step_logpdf, args...; is_primitive=false) + test_rule(rng, step_logpdf, args...; is_primitive=false, interface_only) end @testset "step_filter" begin args = (ordering(model[1]), x, (model[1], y)) @inferred step_filter(args...) - test_rule(rng, step_filter, args...; is_primitive=false) + test_rule(rng, step_filter, args...; is_primitive=false, interface_only) end @testset "invert_dynamics" begin args = (x, x, model[1].transition) @inferred invert_dynamics(args...) - test_rule(rng, invert_dynamics, args...; is_primitive=false) + test_rule(rng, invert_dynamics, args...; is_primitive=false, interface_only) end @testset "step_posterior" begin args = (ordering(model[1]), x, (model[1], y)) @inferred step_posterior(args...) - test_rule(rng, step_posterior, args...; is_primitive=false) + test_rule(rng, step_posterior, args...; is_primitive=false, interface_only) end # Run standard battery of LGSSM tests. From d95557e342a93dac80c084827e116ea7c82633d0 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 27 Sep 2024 14:15:58 +0100 Subject: [PATCH 16/21] Add JET as test dep' --- Project.toml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index e54428b9..d90fe9a3 100644 --- a/Project.toml +++ b/Project.toml @@ -26,6 +26,7 @@ BenchmarkTools = "1" Bessels = "0.2.8" BlockDiagonals = "0.1.7" FillArrays = "0.13.0 - 0.13.7, 1" +JET = "0.9" KernelFunctions = "0.9, 0.10.1" Mooncake = "0.4.3" StaticArrays = "1" @@ -34,9 +35,10 @@ julia = "1.6" [extras] BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["BenchmarkTools", "Mooncake", "Pkg", "Test"] +test = ["BenchmarkTools", "JET", "Mooncake", "Pkg", "Test"] From 0391be384815fa1240989fda4466a185a43c25ab Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 27 Sep 2024 14:17:57 +0100 Subject: [PATCH 17/21] Tidy up and use JET rather than inferred --- test/front_matter.jl | 27 ++++++++-- test/models/lgssm.jl | 59 +++++---------------- test/models/linear_gaussian_conditionals.jl | 50 ++++++----------- test/models/missings.jl | 4 +- test/test_util.jl | 38 +++++++------ 5 files changed, 73 insertions(+), 105 deletions(-) diff --git a/test/front_matter.jl b/test/front_matter.jl index ae385e1c..3b8d47b7 100644 --- a/test/front_matter.jl +++ b/test/front_matter.jl @@ -2,8 +2,9 @@ using AbstractGPs, BlockDiagonals, FillArrays, - LinearAlgebra, + JET, KernelFunctions, + LinearAlgebra, Mooncake, Random, StaticArrays, @@ -23,7 +24,26 @@ using TemporalGPs: scan_emit, transform_model_and_obs, RectilinearGrid, - RegularInTime + RegularInTime, + posterior_and_lml, + predict, + predict_marginals, + step_marginals, + step_logpdf, + step_filter, + step_rand, + invert_dynamics, + step_posterior, + storage_type, + is_of_storage_type, + ArrayStorage, + SArrayStorage, + SmallOutputLGC, + LargeOutputLGC, + ScalarOutputLGC, + Forward, + Reverse, + ordering ENV["TESTING"] = "TRUE" @@ -34,8 +54,5 @@ ENV["TESTING"] = "TRUE" # ENV["GROUP"] = "test gp" const GROUP = get(ENV, "GROUP", "all") -const TEST_TYPE_INFER = false # Test type stability over the tests -const TEST_ALLOC = false # Test allocations over the tests - include("test_util.jl") include(joinpath("models", "model_test_utils.jl")) diff --git a/test/models/lgssm.jl b/test/models/lgssm.jl index 20f026e8..f8b85ba7 100644 --- a/test/models/lgssm.jl +++ b/test/models/lgssm.jl @@ -1,23 +1,3 @@ -using TemporalGPs: - TemporalGPs, - predict, - step_marginals, - step_logpdf, - step_filter, - step_rand, - invert_dynamics, - step_posterior, - storage_type, - is_of_storage_type, - ArrayStorage, - SArrayStorage, - SmallOutputLGC, - LargeOutputLGC, - ScalarOutputLGC, - Forward, - Reverse, - ordering - println("lgssm:") @testset "lgssm" begin @@ -58,7 +38,8 @@ println("lgssm:") # Print current iteration to prevent CI timing out. println( "(time_varying=$tv, N=$N, Dlat=$Dlat, Dobs=$Dobs, " * - "storage=$(storage.name), emissions=$(emission.val), ordering=$order)", + "storage=$(storage.name), emissions=$(emission.val), ordering=$order, " * + "Q=$Q)", ) # Build LGSSM. @@ -82,30 +63,16 @@ println("lgssm:") y = first(rand(model)) x = TemporalGPs.x0(model) - interface_only = true - @testset "step_marginals" begin - @inferred step_marginals(x, model[1]) - test_rule(rng, step_marginals, x, model[1]; is_primitive=false, interface_only) - end - @testset "step_logpdf" begin - args = (ordering(model[1]), x, (model[1], y)) - @inferred step_logpdf(args...) - test_rule(rng, step_logpdf, args...; is_primitive=false, interface_only) - end - @testset "step_filter" begin - args = (ordering(model[1]), x, (model[1], y)) - @inferred step_filter(args...) - test_rule(rng, step_filter, args...; is_primitive=false, interface_only) - end - @testset "invert_dynamics" begin - args = (x, x, model[1].transition) - @inferred invert_dynamics(args...) - test_rule(rng, invert_dynamics, args...; is_primitive=false, interface_only) - end - @testset "step_posterior" begin - args = (ordering(model[1]), x, (model[1], y)) - @inferred step_posterior(args...) - test_rule(rng, step_posterior, args...; is_primitive=false, interface_only) + perf_flag = storage.val isa SArrayStorage ? :allocs : :none + @testset "$f" for (f, args...) in Any[ + (step_marginals, x, model[1]), + (step_logpdf, ordering(model[1]), x, (model[1], y)), + (step_filter, ordering(model[1]), x, (model[1], y)), + (invert_dynamics, x, x, model[1].transition), + (step_posterior, ordering(model[1]), x, (model[1], y)), + ] + @test_opt target_modules=[TemporalGPs] f(args...) + test_rule(rng, f, args...; is_primitive=false, interface_only=true, perf_flag) end # Run standard battery of LGSSM tests. @@ -116,7 +83,7 @@ println("lgssm:") max_primal_allocs=25, max_forward_allocs=25, max_backward_allocs=25, - check_allocs=TEST_ALLOC && storage.val isa SArrayStorage, + check_allocs=storage.val isa SArrayStorage, ) end end diff --git a/test/models/linear_gaussian_conditionals.jl b/test/models/linear_gaussian_conditionals.jl index 78be0a39..d6fb055f 100644 --- a/test/models/linear_gaussian_conditionals.jl +++ b/test/models/linear_gaussian_conditionals.jl @@ -1,19 +1,9 @@ -using TemporalGPs: posterior_and_lml, predict, predict_marginals -using Test - println("linear_gaussian_conditionals:") @testset "linear_gaussian_conditionals" begin Dlats = [1, 3] Dobss = [1, 2] - # Dlats = [3] - # Dobss = [2] - storages = [ - (name="dense storage Float64", val=ArrayStorage(Float64)), - ] - Q_types = [ - Val(:dense), - Val(:diag), - ] + storages = [(name="dense storage Float64", val=ArrayStorage(Float64))] + Q_types = [Val(:dense), Val(:diag)] @testset "SmallOutputLGC (Dlat=$Dlat, Dobs=$Dobs, Q=$(Q_type), $(storage.name))" for Dlat in Dlats, @@ -27,11 +17,9 @@ println("linear_gaussian_conditionals:") x = random_gaussian(rng, Dlat, storage.val) model = random_small_output_lgc(rng, Dlat, Dobs, Q_type, storage.val) + check_allocs = storage.val isa SArrayStorage test_interface( - rng, model, x; - check_adjoints=true, - check_inferred=TEST_TYPE_INFER, - check_allocs=TEST_ALLOC && storage.val isa SArrayStorage, + rng, model, x; check_adjoints=true, check_inferred=true, check_allocs ) Q_type == Val(:diag) && @testset "missing data" begin @@ -57,9 +45,8 @@ println("linear_gaussian_conditionals:") @test lml ≈ lml_new atol=1e-8 rtol=1e-8 # Check that everything infers and AD gives the right answer. - @inferred posterior_and_lml(x, model, y_missing) - # BROKEN: gradients with Zygote look fine but are failing because of ChainRulesTestUtils checks see https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/270 - # test_zygote_grad(posterior_and_lml, x, model, y_missing) + @test_opt target_modules=[TemporalGPs] posterior_and_lml(x, model, y_missing) + test_rule(rng, posterior_and_lml, x, model, y_missing; is_primitive=false) end end @@ -103,27 +90,25 @@ println("linear_gaussian_conditionals:") @test lml_vanilla ≈ lml_large rtol=1e-8 atol=1e-8 # Check that everything infers and AD gives the right answer. - @inferred posterior_and_lml(x, model, y_missing) + @test_opt target_modules=[TemporalGPs] posterior_and_lml(x, model, y_missing) test_rule(rng, posterior_and_lml, x, model, y_missing; is_primitive=false) end end + check_allocs = storage.val isa SArrayStorage test_interface( - rng, model, x; - check_adjoints=true, - check_inferred=TEST_TYPE_INFER, - check_allocs=TEST_ALLOC && storage.val isa SArrayStorage, + rng, model, x; check_adjoints=true, check_inferred=true, check_allocs ) end - @testset "ScalarOutputLGC (Dlat=$Dlat, ($storage.name))" for + @testset "ScalarOutputLGC (Dlat=$Dlat, $(storage.name))" for Dlat in Dlats, storage in [ (name="dense storage Float64", val=ArrayStorage(Float64)), (name="static storage Float64", val=SArrayStorage(Float64)), ] - println("ScalarOutputLGC (Dlat=$Dlat, ($storage.name))") + println("ScalarOutputLGC (Dlat=$Dlat, $(storage.name))") rng = MersenneTwister(123456) x = random_gaussian(rng, Dlat, storage.val) @@ -140,11 +125,9 @@ println("linear_gaussian_conditionals:") @test lml_vanilla ≈ lml_scalar end + check_allocs = storage.val isa SArrayStorage test_interface( - rng, model, x; - check_adjoints=true, - check_inferred=TEST_TYPE_INFER, - check_allocs=TEST_ALLOC && storage.val isa SArrayStorage, + rng, model, x; check_adjoints=true, check_inferred=true, check_allocs ) end @@ -167,10 +150,7 @@ println("linear_gaussian_conditionals:") @test TemporalGPs.dim_in(model) == Din test_interface( - rng, model, x; - check_adjoints=true, - check_inferred=TEST_TYPE_INFER, - check_allocs=TEST_ALLOC, + rng, model, x; check_adjoints=true, check_inferred=true, check_allocs=false ) @testset "consistency with SmallOutputLGC" begin @@ -202,7 +182,7 @@ println("linear_gaussian_conditionals:") @test lml_vanilla ≈ lml_large rtol=1e-8 atol=1e-8 # Check that everything infers and AD gives the right answer. - @inferred posterior_and_lml(x, model, y_missing) + @test_opt target_modules=[TemporalGPs] posterior_and_lml(x, model, y_missing) test_rule(rng, posterior_and_lml, x, model, y_missing; is_primitive=false) end end diff --git a/test/models/missings.jl b/test/models/missings.jl index fb78be70..bc353a1c 100644 --- a/test/models/missings.jl +++ b/test/models/missings.jl @@ -174,7 +174,7 @@ end # Check logpdf and inference run, infer. - @inferred logpdf(model, y_missing) - @inferred posterior(model, y_missing) + @test_opt target_modules=[TemporalGPs] logpdf(model, y_missing) + @test_opt target_modules=[TemporalGPs] posterior(model, y_missing) end end; diff --git a/test/test_util.jl b/test/test_util.jl index f5892efb..e071ec26 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -21,23 +21,24 @@ using TemporalGPs: function test_interface( rng::AbstractRNG, conditional::AbstractLGC, x::Gaussian; - check_inferred=TEST_TYPE_INFER, check_adjoints=true, check_allocs=TEST_ALLOC, kwargs..., + check_inferred=true, check_adjoints=true, check_allocs=true, ) x_val = rand(rng, x) y = conditional_rand(rng, conditional, x_val) perf_flag = check_allocs ? :allocs : :none + is_primitive = false @testset "rand" begin @test length(y) == dim_out(conditional) args = (TemporalGPs.ε_randn(rng, conditional), conditional, x_val) - check_inferred && @inferred conditional_rand(args...) - check_adjoints && test_rule(rng, conditional_rand, args...; perf_flag, is_primitive=false) + check_inferred && @test_opt target_modules=[TemporalGPs] conditional_rand(args...) + check_adjoints && test_rule(rng, conditional_rand, args...; perf_flag, is_primitive) end @testset "predict" begin @test predict(x, conditional) isa Gaussian - check_inferred && @inferred predict(x, conditional) - check_adjoints && test_rule(rng, predict, x, conditional; perf_flag, is_primitive=false) + check_inferred && @test_opt target_modules=[TemporalGPs] predict(x, conditional) + check_adjoints && test_rule(rng, predict, x, conditional; perf_flag, is_primitive) end conditional isa ScalarOutputLGC || @testset "predict_marginals" begin @@ -52,15 +53,15 @@ function test_interface( @testset "posterior_and_lml" begin args = (x, conditional, y) @test posterior_and_lml(args...) isa Tuple{Gaussian, Real} - check_inferred && @inferred posterior_and_lml(args...) - check_adjoints && test_rule(rng, posterior_and_lml, args...; perf_flag, is_primitive=false) + check_inferred && @test_opt target_modules=[TemporalGPs] posterior_and_lml(args...) + check_adjoints && test_rule(rng, posterior_and_lml, args...; perf_flag, is_primitive) end end """ test_interface( rng::AbstractRNG, ssm::AbstractLGSSM; - check_inferred=TEST_TYPE_INFER, check_adjoints=true, check_allocs=TEST_ALLOC, kwargs... + check_inferred=true, check_adjoints=true, check_allocs=true, ) Basic consistency tests that any LGSSM should be able to satisfy. The purpose of these tests @@ -69,7 +70,7 @@ consistent and implements the required interface. """ function test_interface( rng::AbstractRNG, ssm::AbstractLGSSM; - check_inferred=TEST_TYPE_INFER, check_adjoints=true, check_allocs=TEST_ALLOC, rtol, atol, kwargs... + check_inferred=true, check_adjoints=true, check_allocs=true, ) perf_flag = check_allocs ? :allocs : :none y_no_missing = rand(rng, ssm) @@ -78,13 +79,17 @@ function test_interface( @test is_of_storage_type(y_no_missing[1], storage_type(ssm)) @test y_no_missing isa AbstractVector @test length(y_no_missing) == length(ssm) - check_inferred && @inferred rand(rng, ssm) + check_inferred && @test_opt target_modules=[TemporalGPs] rand(rng, ssm) rng = MersenneTwister(123456) - check_adjoints && test_rule(rng, rand, rng, ssm; perf_flag, interface_only=true, is_primitive=false) + if check_adjoints + test_rule( + rng, rand, rng, ssm; perf_flag, interface_only=true, is_primitive=false + ) + end end @testset "basics" begin - @inferred storage_type(ssm) + @test_opt target_modules=[TemporalGPs] storage_type(ssm) @test length(ssm) == length(y_no_missing) end @@ -93,7 +98,7 @@ function test_interface( @test is_of_storage_type(xs, storage_type(ssm)) @test xs isa AbstractVector{<:Gaussian} @test length(xs) == length(ssm) - check_inferred && @inferred marginals(ssm) + check_inferred && @test_opt target_modules=[TemporalGPs] marginals(ssm) if check_adjoints test_rule( rng, scan_emit, step_marginals, ssm, x0(ssm), eachindex(ssm); @@ -104,7 +109,6 @@ function test_interface( @testset "$(data.name)" for data in [ (name="no-missings", y=y_no_missing), - # (name="with-missings", y=y_missing), ] _check_inferred = data.name == "with-missings" ? false : check_inferred @@ -113,7 +117,7 @@ function test_interface( lml = logpdf(ssm, y) @test lml isa Real @test is_of_storage_type(lml, storage_type(ssm)) - _check_inferred && @inferred logpdf(ssm, y) + _check_inferred && @test_opt target_modules=[TemporalGPs] logpdf(ssm, y) if check_adjoints test_rule( rng, scan_emit, step_logpdf, zip(ssm, y), x0(ssm), eachindex(ssm); @@ -126,7 +130,7 @@ function test_interface( @test is_of_storage_type(xs, storage_type(ssm)) @test xs isa AbstractVector{<:Gaussian} @test length(xs) == length(ssm) - _check_inferred && @inferred _filter(ssm, y) + _check_inferred && @test_opt target_modules=[TemporalGPs] _filter(ssm, y) if check_adjoints test_rule( rng, scan_emit, step_filter, zip(ssm, y), x0(ssm), eachindex(ssm); @@ -138,7 +142,7 @@ function test_interface( posterior_ssm = posterior(ssm, y) @test length(posterior_ssm) == length(ssm) @test ordering(posterior_ssm) != ordering(ssm) - _check_inferred && @inferred posterior(ssm, y) + _check_inferred && @test_opt target_modules=[TemporalGPs] posterior(ssm, y) if check_adjoints test_rule( rng, posterior, ssm, y; From 70a8b1d80ba4b594bb7013f2868f6cecec0fc5b1 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 27 Sep 2024 14:27:50 +0100 Subject: [PATCH 18/21] Some fixes --- test/models/lgssm.jl | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/test/models/lgssm.jl b/test/models/lgssm.jl index f8b85ba7..621c27f2 100644 --- a/test/models/lgssm.jl +++ b/test/models/lgssm.jl @@ -76,14 +76,6 @@ println("lgssm:") end # Run standard battery of LGSSM tests. - test_interface( - rng, model; - rtol=1e-5, - atol=1e-5, - max_primal_allocs=25, - max_forward_allocs=25, - max_backward_allocs=25, - check_allocs=storage.val isa SArrayStorage, - ) + test_interface(rng, model; check_allocs=false) end end From 5bae510b907b9d57ccefa1bdced07f66a010ef73 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 27 Sep 2024 14:34:50 +0100 Subject: [PATCH 19/21] Discuss the changes in this release --- NEWS.md | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/NEWS.md b/NEWS.md index 37ca39a7..ba958176 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,18 @@ +# 0.7 + +Mooncake.jl (and probably Enzyme.jl) is now able to differentiate everything in +TemporalGPs.jl _reasonably_ efficiently, and only requires a single rule (for time_exp). +This is in stark contrast with Zygote.jl, which required roughly 2.5k lines to achieve +reasonable performance. This code was not robust, required maintenance from time-to-time, +and generally made making progress on improvements to this library hard to make. +Consequently, in this version of TemporalGPs, we have removed all Zygote-related +functionality, and now recommend that Mooncake.jl (or perhaps Enzyme.jl) is used to +differentiate code in this package. In some places Mooncake.jl achieves worse performance +than Zygote.jl, but it is worth it for the amount of code that has been removed. + +If you wish to use Zygote + TemporalGPs, you should restrict yourself to the 0.6 series of +this package. + # 0.5.12 - A collection of examples of inference, and inference + learning, have been added. From 84f1d3f1dff1b0f07416fe40ed8c147c74c3d084 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 27 Sep 2024 17:49:30 +0100 Subject: [PATCH 20/21] Figure out how to avoid bad gradients --- src/gp/lti_sde.jl | 24 ++++++++++++++++++++++++ test/front_matter.jl | 11 +++++++++++ test/gp/lti_sde.jl | 16 +++++++--------- 3 files changed, 42 insertions(+), 9 deletions(-) diff --git a/src/gp/lti_sde.jl b/src/gp/lti_sde.jl index dccabac8..18880d63 100644 --- a/src/gp/lti_sde.jl +++ b/src/gp/lti_sde.jl @@ -187,6 +187,8 @@ function stationary_distribution(k::SimpleKernel, ::ArrayStorage{T}) where {T<:R return Gaussian(collect(x.m), collect(x.P)) end +safe_to_product(::Kernel) = false + # Matern-1/2 function to_sde(::Matern12Kernel, ::SArrayStorage{T}) where {T<:Real} @@ -203,6 +205,8 @@ function stationary_distribution(::Matern12Kernel, ::SArrayStorage{T}) where {T< ) end +safe_to_product(::Matern12Kernel) = true + # Matern - 3/2 function to_sde(::Matern32Kernel, ::SArrayStorage{T}) where {T<:Real} @@ -220,6 +224,8 @@ function stationary_distribution(::Matern32Kernel, ::SArrayStorage{T}) where {T< ) end +safe_to_product(::Matern32Kernel) = true + # Matern - 5/2 function to_sde(::Matern52Kernel, ::SArrayStorage{T}) where {T<:Real} @@ -237,6 +243,8 @@ function stationary_distribution(::Matern52Kernel, ::SArrayStorage{T}) where {T< return Gaussian(m, P) end +safe_to_product(::Matern52Kernel) = true + # Cosine function to_sde(::CosineKernel, ::SArrayStorage{T}) where {T} @@ -252,6 +260,8 @@ function stationary_distribution(::CosineKernel, ::SArrayStorage{T}) where {T<:R return Gaussian(m, P) end +safe_to_product(::CosineKernel) = true + # ApproxPeriodicKernel # The periodic kernel is approximated by a sum of cosine kernels with different frequencies. @@ -309,6 +319,8 @@ function stationary_distribution(kernel::ApproxPeriodicKernel{N}, storage::Array return Gaussian(m, P) end +safe_to_product(::ApproxPeriodicKernel) = true + # Constant function TemporalGPs.to_sde(::ConstantKernel, ::SArrayStorage{T}) where {T<:Real} @@ -322,6 +334,9 @@ function TemporalGPs.stationary_distribution(k::ConstantKernel, ::SArrayStorage{ return TemporalGPs.Gaussian(SVector{1, T}(0), SMatrix{1, 1, T}(T(only(k.c)))) end +safe_to_product(::ConstantKernel) = true + + # Scaled function to_sde(k::ScaledKernel, storage::StorageType{T}) where {T<:Real} @@ -334,6 +349,8 @@ function stationary_distribution(k::ScaledKernel, storage::StorageType) return stationary_distribution(k.kernel, storage) end +safe_to_product(k::ScaledKernel) = safe_to_product(k.kernel) + function lgssm_components(k::ScaledKernel, ts::AbstractVector, storage_type::StorageType) As, as, Qs, emission_proj, x0 = lgssm_components(k.kernel, ts, storage_type) σ = sqrt(convert(eltype(storage_type), only(k.σ²))) @@ -361,6 +378,8 @@ function stationary_distribution( return stationary_distribution(k.kernel, storage) end +safe_to_product(::TransformedKernel{<:Kernel, <:ScaleTransform}) = false + function lgssm_components( k::TransformedKernel{<:Kernel, <:ScaleTransform}, ts::AbstractVector, @@ -377,7 +396,12 @@ apply_stretch(a, ts::RegularSpacing) = RegularSpacing(a * ts.t0, a * ts.Δt, ts. # Product +safe_to_product(k::KernelProduct) = all(safe_to_product, k.kernels) + function lgssm_components(k::KernelProduct, ts::AbstractVector, storage::StorageType) + + safe_to_product(k) || throw(ArgumentError("Not all kernels in k are safe to product.")) + sde_kernels = to_sde.(k.kernels, Ref(storage)) F_kernels = getindex.(sde_kernels, 1) F = foldl(_kron_add, F_kernels) diff --git a/test/front_matter.jl b/test/front_matter.jl index 3b8d47b7..8f82c4c6 100644 --- a/test/front_matter.jl +++ b/test/front_matter.jl @@ -54,5 +54,16 @@ ENV["TESTING"] = "TRUE" # ENV["GROUP"] = "test gp" const GROUP = get(ENV, "GROUP", "all") +# Some test-local type piracy. ConstantKernel doesn't have a default constructor, so +# Mooncake's testing functionality doesn't work with it properly. To resolve this, I just +# add a default-style constructor here. +@eval function KernelFunctions.ConstantKernel{P}(c::Vector{P}) where {P<:Real} + $(Expr(:new, :(ConstantKernel{P}), :c)) +end + +@eval function PeriodicKernel{P}(c::Vector{P}) where {P<:Real} + $(Expr(:new, :(PeriodicKernel{P}), :c)) +end + include("test_util.jl") include(joinpath("models", "model_test_utils.jl")) diff --git a/test/gp/lti_sde.jl b/test/gp/lti_sde.jl index 7b0066e3..1a4ef2f9 100644 --- a/test/gp/lti_sde.jl +++ b/test/gp/lti_sde.jl @@ -5,10 +5,9 @@ using Test # Everything is tested once the LGSSM is constructed, so it is sufficient just to ensure # that Zygote can handle construction. -function _construction_tester(f_naive::GP, storage::StorageType, σ², t::AbstractVector) +function _logpdf_tester(f_naive::GP, y, storage::StorageType, σ², t::AbstractVector) f = to_sde(f_naive, storage) - fx = f(t, σ²...) - return build_lgssm(fx) + return logpdf(f(t, σ²...), y) end println("lti_sde:") @@ -112,15 +111,14 @@ println("lti_sde:") # Product kernels ( - name="prod-Matern12Kernel-Matern32Kernel", - val=1.5 * Matern12Kernel() ∘ ScaleTransform(0.1) * Matern32Kernel() ∘ - ScaleTransform(1.1), + name="prod-Matern52Kernel-Matern32Kernel", + val=(1.5 * Matern52Kernel() * Matern32Kernel()) ∘ ScaleTransform(0.01), ), ( name="prod-Matern32Kernel-Matern52Kernel-ConstantKernel", val=3.0 * Matern32Kernel() * Matern52Kernel() * ConstantKernel(), ), - # THIS IS KNOWN NOT TO WORK! + # This is known not to work at all (not a gradient problem). # ( # name="prod-(Matern32Kernel + ConstantKernel) * Matern52Kernel", # val=(Matern32Kernel() + ConstantKernel()) * Matern52Kernel(), @@ -203,8 +201,8 @@ println("lti_sde:") end test_rule( - rng, _construction_tester, f_naive, storage.val, σ².val, t.val; - is_primitive=false, interface_only=true, + rng, _logpdf_tester, f_naive, y, storage.val, σ².val, t.val; + is_primitive=false, ) end end From 1a7edeeb7a4f62907062e57d1797cb12ea872518 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 27 Sep 2024 18:54:44 +0100 Subject: [PATCH 21/21] Tidy up example --- examples/approx_space_time_learning.jl | 49 -------------------------- 1 file changed, 49 deletions(-) diff --git a/examples/approx_space_time_learning.jl b/examples/approx_space_time_learning.jl index b05c88b9..d6d7aff1 100644 --- a/examples/approx_space_time_learning.jl +++ b/examples/approx_space_time_learning.jl @@ -54,61 +54,12 @@ y = sin.(first.(xs)) .+ cos.(last.(xs)) + sqrt.(params.var_noise) .* randn(lengt # Spatial pseudo-point inputs. z_r = collect(range(-3.0, 3.0; length=5)); -# # Specify an objective function for Optim to minimise in terms of x and y. -# # We choose the usual negative log marginal likelihood (NLML). -# function make_objective(unpack, x, y, z_r) -# function objective(flat_params) -# params = unpack(flat_params) -# f = build_gp(params) -# return elbo(f(x, params.var_noise), y, z_r) -# end -# return objective -# end -# objective = make_objective(unpack, x, y, z_r) - function objective(flat_params) params = unpack(flat_params) f = build_gp(params) return -elbo(f(x, params.var_noise), y, z_r) end -# using Random -# # y = y -# # z_r = z_r -# # fx = build_gp(unpack(flat_initial_params))(x, params.var_noise) -# # fx_dtc = TemporalGPs.dtcify(z_r, fx) -# # lgssm = TemporalGPs.build_lgssm(fx_dtc) -# # Σs = lgssm.emissions.fan_out.Q -# # marg_diags = TemporalGPs.marginals_diag(lgssm) - -# # k = fx_dtc.f.f.kernel -# # Cf_diags = TemporalGPs.kernel_diagonals(k, fx_dtc.x) - -# # # Transform a vector into a vector-of-vectors. -# # y_vecs = TemporalGPs.restructure(y, lgssm.emissions) - -# # tmp = TemporalGPs.zygote_friendly_map( -# # ((Σ, Cf_diag, marg_diag, yn), ) -> begin -# # Σ_, _ = TemporalGPs.fill_in_missings(Σ, yn) -# # return sum(TemporalGPs.diag(Σ_ \ (Cf_diag - marg_diag.P))) - -# # count(ismissing, yn) + size(Σ_, 1) -# # end, -# # zip(Σs, Cf_diags, marg_diags, y_vecs), -# # ) - -# # logpdf(lgssm, y_vecs) # this is the failing thing - -# # for _ in 1:10 -# # Tapir.TestUtils.test_rule( -# # Xoshiro(123456), objective, flat_initial_params; -# # perf_flag=:none, -# # interp=Tapir.TapirInterpreter(), -# # interface_only=false, -# # is_primitive=false, -# # safety_on=false, -# # ) -# # end - # Optimise using Optim. function objective_grad(rule, flat_params) return Mooncake.value_and_gradient!!(rule, objective, flat_params)[2][2]