Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fit GP prior and fix bugs #88

Merged
merged 2 commits into from
Jun 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 122 additions & 8 deletions src/algorithms/bayesian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,42 +2,146 @@ using AbstractGPs, IntervalArithmetic

mutable struct ZeroOrderGPSurrogate <: Function
f::Function
gp::GP
gps::Vector{<:GP}
X::Vector{Vector{Float64}}
y::Union{Vector{Vector{Float64}}, Vector{Float64}}
noise::Float64
std_multiple::Float64
mode::Symbol
N::Int
every::Int
skip::Int
last::Union{Symbol, Int}
fit_prior::Bool
fit_noise::Bool
counter::Int
end
function ZeroOrderGPSurrogate(
f, x; kernel = SqExponentialKernel(),
X = [x], y = [f(x)],
noise = 1e-8, std_multiple = 3.0, mode = :interval,
f, x; kernel = SqExponentialKernel(), X = [x], y = [f(x)], noise = 1e-8,
std_multiple = 3.0, mode = :interval, every = 5, skip = 10, last = :all,
fit_prior = true, fit_noise = false,
)
@assert noise > 0
gp = GP(kernel)
gps = [GP(AbstractGPs.ConstMean(0.0), kernel) for _ in 1:length(y[1])]
N = length(y[1])
return ZeroOrderGPSurrogate(
f, gp, X, y, noise, std_multiple, mode, N,
f, gps, X, y, noise, std_multiple, mode, N, every, skip, last,
fit_prior, fit_noise, 0,
)
end

function get_x_lb_ub(k::Kernel)
x, un = flatten(k)
lb = fill(-Inf, length(x))
ub = fill(Inf, length(x))
return x, lb, ub, un
end

function fit_mle!(s, gps, X, y, noise, last, fit_noise)
xs_uns_1 = flatten.(getproperty.(gps, :mean))
xs_lbs_ubs_uns_2 = get_x_lb_ub.(getproperty.(gps, :kernel))

x1 = mapreduce(vcat, xs_uns_1) do x_un
x_un[1]
end
x2 = mapreduce(vcat, xs_lbs_ubs_uns_2) do x_lb_ub_un
x_lb_ub_un[1]
end
lb2 = mapreduce(vcat, xs_lbs_ubs_uns_2) do x_lb_ub_un
x_lb_ub_un[2]
end
ub2 = mapreduce(vcat, xs_lbs_ubs_uns_2) do x_lb_ub_un
x_lb_ub_un[3]
end
if fit_noise
x = [x1; x2; noise]
lb = [fill(-Inf, length(x1)); lb2; 1e-8]
ub = [fill(Inf, length(x1)); ub2; Inf]
else
x = [x1; x2]
lb = [fill(-Inf, length(x1)); lb2]
ub = [fill(Inf, length(x1)); ub2]
end
if last == :all
_X = X
_y = y
else
_X = X[max(end-last+1, 1):end]
_y = y[max(end-last+1, 1):end]
end
obj(θ) = begin
if _y isa Vector{<:Vector}
offset = 0
return -sum(map(1:length(_y[1])) do i
l1 = length(xs_uns_1[i][1])
un1 = xs_uns_1[i][2]
l2 = length(xs_lbs_ubs_uns_2[i][1])
un2 = xs_lbs_ubs_uns_2[i][4]
_gp = GP(un1(θ[offset+1:offset+l1]), un2(θ[offset+l1+1:offset+l1+l2]))
offset += l1 + l2
if fit_noise
return logpdf(_gp(_X, θ[end]), getindex.(_y, i))
else
return logpdf(_gp(_X, noise), getindex.(_y, i))
end
end)
else
l1 = length(xs_uns_1[1][1])
un1 = xs_uns_1[1][2]
l2 = length(xs_lbs_ubs_uns_2[1][1])
un2 = xs_lbs_ubs_uns_2[1][4]
_gp = GP(un1(θ[1:l1]), un2(θ[l1+1:l1+l2]))
if fit_noise
return -logpdf(_gp(_X, θ[end]), _y)
else
return -logpdf(_gp(_X, noise), _y)
end
end
end
m = Model(obj)
addvar!(m, lb, ub)
options = IpoptOptions(max_iter = 20, print_level = 0)
res = optimize(m, IpoptAlg(), x, options = options)
x = res.minimizer
offset = 0
gps, noise = map(1:length(_y[1])) do i
l1 = length(xs_uns_1[i][1])
un1 = xs_uns_1[i][2]
l2 = length(xs_lbs_ubs_uns_2[i][1])
un2 = xs_lbs_ubs_uns_2[i][4]
out = GP(un1(x[offset+1:offset+l1]), un2(x[offset+l1+1:offset+l1+l2]))
offset += l1 + l2
return out
end, fit_noise ? x[end] : noise
s.gps = gps
s.noise = noise
return s
end
function ChainRulesCore.rrule(::typeof(fit_mle!), s, gp, X, y, noise, last, fit_noise)
return fit_mle!(s, gp, X, y, noise, last, fit_noise), _ -> begin
return ntuple(_ -> NoTangent(), Val(8))
end
end

function (s::ZeroOrderGPSurrogate)(x)
if s.mode == :exact
y = s.f(x)
s.X = vcat(s.X, [x])
s.y = vcat(s.y, [y])
s.counter += 1
if s.fit_prior && s.counter > s.skip * s.every && (s.counter % s.every) == 0
fit_mle!(s, s.gps, s.X, s.y, s.noise, s.last, s.fit_noise)
end
return Interval.(y, y)
else
if eltype(s.y) <: Real
_m, _v = mean_and_var(posterior(
s.gp(s.X, s.noise), s.y,
s.gps[1](s.X, s.noise), s.y,
), [x])
m, v = _m[1], _v[1]
else
_gp = s.gp(s.X, s.noise)
ms_vs = map(1:s.N) do i
_gp = s.gps[i](s.X, s.noise)
mean_and_var(posterior(_gp, getindex.(s.y, i)), [x])
end
m = reduce(vcat, getindex.(ms_vs, 1))
Expand Down Expand Up @@ -146,6 +250,11 @@ function BayesOptOptions(;
ctol = 1e-4,
ftol = 1e-4,
postoptimize = false,
fit_prior = true,
fit_noise = false,
every = 2,
skip = 2,
last = :all,
)
return BayesOptOptions(
sub_options,
Expand All @@ -159,6 +268,11 @@ function BayesOptOptions(;
std_multiple = std_multiple,
kernel = kernel,
noise = noise,
every = every,
skip = skip,
last = last,
fit_prior = fit_prior,
fit_noise = fit_noise,
),
)
end
Expand Down
7 changes: 3 additions & 4 deletions src/wrappers/ipopt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,11 @@ function get_ipopt_problem(model::VecModel, x0::AbstractVector, first_order::Boo
), obj.counter
end
function get_ipopt_problem(obj, ineq_constr, eq_constr, x0, xlb, xub, first_order, linear)
nvars = 0
nvars = length(x0)
if ineq_constr !== nothing
ineqJ0 = Zygote.jacobian(ineq_constr, x0)[1]
ineqJ0 = linear ? sparse(ineqJ0) : ineqJ0
ineq_nconstr, nvars = size(ineqJ0)
ineq_nconstr, _ = size(ineqJ0)
Joffset = nvalues(ineqJ0)
else
ineqJ0 = nothing
Expand All @@ -157,12 +157,11 @@ function get_ipopt_problem(obj, ineq_constr, eq_constr, x0, xlb, xub, first_orde
if eq_constr !== nothing
eqJ0 = Zygote.jacobian(eq_constr, x0)[1]
eqJ0 = linear ? sparse(eqJ0) : eqJ0
eq_nconstr, nvars = size(eqJ0)
eq_nconstr, _ = size(eqJ0)
else
eqJ0 = nothing
eq_nconstr = 0
end
@assert nvars > 0
lag(factor, y) = x -> begin
return factor * obj(x) +
_dot(ineq_constr, x, @view(y[1:ineq_nconstr])) +
Expand Down
Loading