diff --git a/examples/approx_space_time_learning.jl b/examples/approx_space_time_learning.jl
index b05c88b..d6d7aff 100644
--- a/examples/approx_space_time_learning.jl
+++ b/examples/approx_space_time_learning.jl
@@ -54,61 +54,12 @@ y = sin.(first.(xs)) .+ cos.(last.(xs)) + sqrt.(params.var_noise) .* randn(lengt
 # Spatial pseudo-point inputs.
 z_r = collect(range(-3.0, 3.0; length=5));
 
-# # Specify an objective function for Optim to minimise in terms of x and y.
-# # We choose the usual negative log marginal likelihood (NLML).
-# function make_objective(unpack, x, y, z_r)
-#     function objective(flat_params)
-#         params = unpack(flat_params)
-#         f = build_gp(params)
-#         return elbo(f(x, params.var_noise), y, z_r)
-#     end
-#     return objective
-# end
-# objective = make_objective(unpack, x, y, z_r)
-
 function objective(flat_params)
     params = unpack(flat_params)
     f = build_gp(params)
     return -elbo(f(x, params.var_noise), y, z_r)
 end
 
-# using Random
-# # y = y
-# # z_r = z_r
-# # fx = build_gp(unpack(flat_initial_params))(x, params.var_noise)
-# # fx_dtc = TemporalGPs.dtcify(z_r, fx)
-# # lgssm = TemporalGPs.build_lgssm(fx_dtc)
-# # Σs = lgssm.emissions.fan_out.Q
-# # marg_diags = TemporalGPs.marginals_diag(lgssm)
-
-# # k = fx_dtc.f.f.kernel
-# # Cf_diags = TemporalGPs.kernel_diagonals(k, fx_dtc.x)
-
-# # # Transform a vector into a vector-of-vectors.
-# # y_vecs = TemporalGPs.restructure(y, lgssm.emissions)
-
-# # tmp = TemporalGPs.zygote_friendly_map(
-# #     ((Σ, Cf_diag, marg_diag, yn), ) -> begin
-# #         Σ_, _ = TemporalGPs.fill_in_missings(Σ, yn)
-# #         return sum(TemporalGPs.diag(Σ_ \ (Cf_diag - marg_diag.P))) -
-# #             count(ismissing, yn) + size(Σ_, 1)
-# #     end,
-# #     zip(Σs, Cf_diags, marg_diags, y_vecs),
-# # )
-
-# # logpdf(lgssm, y_vecs) # this is the failing thing
-
-# # for _ in 1:10
-# #     Tapir.TestUtils.test_rule(
-# #         Xoshiro(123456), objective, flat_initial_params;
-# #         perf_flag=:none,
-# #         interp=Tapir.TapirInterpreter(),
-# #         interface_only=false,
-# #         is_primitive=false,
-# #         safety_on=false,
-# #     )
-# # end
-
 # Optimise using Optim.
 function objective_grad(rule, flat_params)
     return Mooncake.value_and_gradient!!(rule, objective, flat_params)[2][2]
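
For context, a minimal usage sketch of the retained code path (not part of the patch). It assumes the `rule` passed to `objective_grad` is built with `Mooncake.build_rrule`, and that `flat_initial_params` and `unpack` come from the parameter-flattening step earlier in this example; the exact Optim call shown here is an assumption, not taken from the diff.

    using Mooncake, Optim

    # Build the reverse-mode rule once, then reuse it for every gradient call.
    rule = Mooncake.build_rrule(objective, flat_initial_params)

    # Hand Optim the objective and its Mooncake gradient. `inplace=false`
    # tells Optim the gradient function returns a new array rather than
    # mutating one in place.
    training_results = Optim.optimize(
        objective,
        flat_params -> objective_grad(rule, flat_params),
        flat_initial_params,
        BFGS(),
        Optim.Options(; show_trace=true);
        inplace=false,
    )

Note that `value_and_gradient!!(rule, objective, flat_params)` returns the objective value and a tuple of gradients (one per argument, including `objective` itself), which is why `objective_grad` indexes with `[2][2]` to pull out the gradient with respect to `flat_params`.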