
Commit

Resolve merge
willtebbutt committed Sep 27, 2024
2 parents b3405d0 + 1a7edee commit 388e4eb
Showing 1 changed file with 0 additions and 49 deletions.
49 changes: 0 additions & 49 deletions examples/approx_space_time_learning.jl
@@ -54,61 +54,12 @@ y = sin.(first.(xs)) .+ cos.(last.(xs)) + sqrt.(params.var_noise) .* randn(lengt
# Spatial pseudo-point inputs.
z_r = collect(range(-3.0, 3.0; length=5));

# # Specify an objective function for Optim to minimise in terms of x and y.
# # We choose the usual negative log marginal likelihood (NLML).
# function make_objective(unpack, x, y, z_r)
#     function objective(flat_params)
#         params = unpack(flat_params)
#         f = build_gp(params)
#         return elbo(f(x, params.var_noise), y, z_r)
#     end
#     return objective
# end
# objective = make_objective(unpack, x, y, z_r)

function objective(flat_params)
    params = unpack(flat_params)
    f = build_gp(params)
    return -elbo(f(x, params.var_noise), y, z_r)
end

# using Random
# # y = y
# # z_r = z_r
# # fx = build_gp(unpack(flat_initial_params))(x, params.var_noise)
# # fx_dtc = TemporalGPs.dtcify(z_r, fx)
# # lgssm = TemporalGPs.build_lgssm(fx_dtc)
# # Σs = lgssm.emissions.fan_out.Q
# # marg_diags = TemporalGPs.marginals_diag(lgssm)

# # k = fx_dtc.f.f.kernel
# # Cf_diags = TemporalGPs.kernel_diagonals(k, fx_dtc.x)

# # # Transform a vector into a vector-of-vectors.
# # y_vecs = TemporalGPs.restructure(y, lgssm.emissions)

# # tmp = TemporalGPs.zygote_friendly_map(
# #     ((Σ, Cf_diag, marg_diag, yn), ) -> begin
# #         Σ_, _ = TemporalGPs.fill_in_missings(Σ, yn)
# #         return sum(TemporalGPs.diag(Σ_ \ (Cf_diag - marg_diag.P))) -
# #             count(ismissing, yn) + size(Σ_, 1)
# #     end,
# #     zip(Σs, Cf_diags, marg_diags, y_vecs),
# # )

# # logpdf(lgssm, y_vecs) # this is the failing thing

# # for _ in 1:10
# #     Tapir.TestUtils.test_rule(
# #         Xoshiro(123456), objective, flat_initial_params;
# #         perf_flag=:none,
# #         interp=Tapir.TapirInterpreter(),
# #         interface_only=false,
# #         is_primitive=false,
# #         safety_on=false,
# #     )
# # end

# Optimise using Optim.
function objective_grad(rule, flat_params)
    return Mooncake.value_and_gradient!!(rule, objective, flat_params)[2][2]
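The hunk is truncated here: the body of objective_grad and the remainder of the example are collapsed in the view above. For orientation only, and not as the commit's actual code, the kept objective (which minimises the negative ELBO, a lower bound on the log marginal likelihood standing in for the NLML mentioned in the removed comment) and the Mooncake-based gradient could be handed to Optim roughly as sketched below. Here flat_initial_params is an assumption: it is presumed to come from flattening the initial parameters earlier in the example, as suggested by the use of unpack above, and the choice of BFGS is illustrative.

# Hedged sketch, not part of the commit. Assumes `flat_initial_params` and
# `unpack` were created earlier in the example (e.g. by flattening `params`).
using Optim
using Mooncake

# Build a Mooncake reverse-mode rule for the objective once, then reuse it.
rule = Mooncake.build_rrule(objective, flat_initial_params)

# value_and_gradient!! returns (value, (grad_objective, grad_flat_params));
# objective_grad above picks out the gradient w.r.t. flat_params via [2][2].
result = Optim.optimize(
    objective,
    flat_params -> objective_grad(rule, flat_params),
    flat_initial_params,
    BFGS(),
    Optim.Options(; show_trace=true);
    inplace=false,  # the gradient function returns a new array rather than mutating
)

learned_params = unpack(Optim.minimizer(result))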
