Enabling RODE handling #585

Closed · wants to merge 6 commits
287 changes: 199 additions & 88 deletions src/rode_solve.jl
@@ -1,38 +1,191 @@
struct NNRODE{C, W, O, P, B, K, S <: Union{Nothing, AbstractTrainingStrategy}} <:
       NeuralPDEAlgorithm
    chain::C
    W::W
    opt::O
    init_params::P
    autodiff::Bool
    batch::B
    strategy::S
    kwargs::K
end

function NNRODE(chain, W, opt, init_params = nothing;
                strategy = nothing,
                autodiff = false, batch = nothing, kwargs...)
    NNRODE(chain, W, opt, init_params, autodiff, batch, strategy, kwargs)
end
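
# RODEPhi carries the trial solution u(t, W; θ) = u0 + (t - t0) * NN(t, W; θ),
# so the initial condition holds by construction. The chain is either a Lux
# layer (with state st) or an Optimisers.Restructure closure for Flux chains.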

mutable struct RODEPhi{C, T, U, S}
    chain::C
    t0::T
    u0::U
    st::S

    function RODEPhi(chain::Lux.AbstractExplicitLayer, t::Number, u0, st)
        new{typeof(chain), typeof(t), typeof(u0), typeof(st)}(chain, t, u0, st)
    end

    function RODEPhi(re::Optimisers.Restructure, t, u0)
        new{typeof(re), typeof(t), typeof(u0), Nothing}(re, t, u0, nothing)
    end
end
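
# generate_phi_θ_rode builds the RODEPhi trial solution together with the
# initial parameter vector, for either a Lux chain or a Flux chain
# (destructured via Flux.destructure).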

function generate_phi_θ_rode(chain::Lux.AbstractExplicitLayer, t, u0, init_params::Nothing)
    θ, st = Lux.setup(Random.default_rng(), chain)
    RODEPhi(chain, t, u0, st), ComponentArrays.ComponentArray(θ)
end

function generate_phi_θ_rode(chain::Lux.AbstractExplicitLayer, t, u0, init_params)
    θ, st = Lux.setup(Random.default_rng(), chain)
    RODEPhi(chain, t, u0, st), ComponentArrays.ComponentArray(init_params)
end

function generate_phi_θ_rode(chain::Flux.Chain, t, u0, init_params::Nothing)
    θ, re = Flux.destructure(chain)
    RODEPhi(re, t, u0), θ
end

function generate_phi_θ_rode(chain::Flux.Chain, t, u0, init_params)
    θ, re = Flux.destructure(chain)
    RODEPhi(re, t, u0), init_params
end

function (f::RODEPhi{C, T, U})(t::Number, W::Number,
                               θ) where {C <: Lux.AbstractExplicitLayer, T, U <: Number}
    y, st = f.chain(adapt(parameterless_type(θ), [t; W]), θ, f.st)
    ChainRulesCore.@ignore_derivatives f.st = st
    f.u0 + (t - f.t0) * first(y)
end

function (f::RODEPhi{C, T, U})(t::AbstractVector, W::AbstractVector,
                               θ) where {C <: Lux.AbstractExplicitLayer, T, U <: Number}
    # Batch via data as row vectors
    y, st = f.chain(adapt(parameterless_type(θ), [t W]'), θ, f.st)
    ChainRulesCore.@ignore_derivatives f.st = st
    f.u0 .+ (t' .- f.t0) .* y
end

function (f::RODEPhi{C, T, U})(t::Number, W::Number,
                               θ) where {C <: Lux.AbstractExplicitLayer, T, U}
    # The chain takes the (t, W) pair as a single input column
    y, st = f.chain(adapt(parameterless_type(θ), [t; W]), θ, f.st)
    ChainRulesCore.@ignore_derivatives f.st = st
    f.u0 .+ (t .- f.t0) .* y
end

function (f::RODEPhi{C, T, U})(t::AbstractVector, W::AbstractVector,
                               θ) where {C <: Lux.AbstractExplicitLayer, T, U}
    # Batch via data as row vectors
    y, st = f.chain(adapt(parameterless_type(θ), [t W]'), θ, f.st)
    ChainRulesCore.@ignore_derivatives f.st = st
    f.u0 .+ (t' .- f.t0) .* y
end

function (f::RODEPhi{C, T, U})(t::Number, w::Number,
                               θ) where {C <: Optimisers.Restructure, T, U <: Number}
    f.u0 + (t - f.t0) * first(f.chain(θ)(adapt(parameterless_type(θ), [t, w])))
end

function (f::RODEPhi{C, T, U})(t::AbstractVector, W::AbstractVector,
                               θ) where {C <: Optimisers.Restructure, T, U <: Number}
    f.u0 .+ (t' .- f.t0) .* f.chain(θ)(adapt(parameterless_type(θ), [t W]'))
end

function (f::RODEPhi{C, T, U})(t::Number, w::Number,
                               θ) where {C <: Optimisers.Restructure, T, U}
    f.u0 .+ (t - f.t0) .* f.chain(θ)(adapt(parameterless_type(θ), [t, w]))
end

function (f::RODEPhi{C, T, U})(t::AbstractVector, w::AbstractVector,
                               θ) where {C <: Optimisers.Restructure, T, U}
    f.u0 .+ (t' .- f.t0) .* f.chain(θ)(adapt(parameterless_type(θ), [t w]'))
end
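
# rode_dfdx approximates du/dt of the trial solution, either with ForwardDiff
# (autodiff = true) or with a first-order finite difference of step sqrt(eps).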

function rode_dfdx end

function rode_dfdx(phi::RODEPhi{C, T, U}, t::Number, W::Number, θ,
                   autodiff::Bool) where {C, T, U <: Number}
    if autodiff
        ForwardDiff.derivative(t -> phi(t, W, θ), t)
    else
        (phi(t + sqrt(eps(typeof(t))), W, θ) - phi(t, W, θ)) / sqrt(eps(typeof(t)))
    end
end

function rode_dfdx(phi::RODEPhi{C, T, U}, t::Number, W::Number, θ,
                   autodiff::Bool) where {C, T, U <: AbstractVector}
    if autodiff
        # ForwardDiff.derivative handles array-valued outputs at a scalar t
        ForwardDiff.derivative(t -> phi(t, W, θ), t)
    else
        (phi(t + sqrt(eps(typeof(t))), W, θ) - phi(t, W, θ)) / sqrt(eps(typeof(t)))
    end
end

function rode_dfdx(phi::RODEPhi, t::AbstractVector, W::AbstractVector, θ, autodiff::Bool)
    if autodiff
        ForwardDiff.jacobian(t -> phi(t, W, θ), t)
    else
        (phi(t .+ sqrt(eps(eltype(t))), W, θ) - phi(t, W, θ)) ./ sqrt(eps(eltype(t)))
    end
end
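
# inner_loss is the squared RODE residual |du/dt - f(u, p, t, W)|² at the
# training points; the batched methods evaluate the network once per column
# and return the mean over the time grid.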

function inner_loss end

function inner_loss(phi::RODEPhi{C, T, U}, f, autodiff::Bool, t::Number, W::Number, θ,
                    p) where {C, T, U <: Number}
    sum(abs2, rode_dfdx(phi, t, W, θ, autodiff) - f(phi(t, W, θ), p, t, W))
end

function inner_loss(phi::RODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector,
                    W::AbstractVector, θ, p) where {C, T, U <: Number}
    out = phi(t, W, θ)
    fs = reduce(hcat, [f(out[i], p, t[i], W[i]) for i in 1:size(out, 2)])
    dxdtguess = Array(rode_dfdx(phi, t, W, θ, autodiff))
    sum(abs2, dxdtguess .- fs) / length(t)
end

function inner_loss(phi::RODEPhi{C, T, U}, f, autodiff::Bool, t::Number, W::Number, θ,
                    p) where {C, T, U}
    sum(abs2, rode_dfdx(phi, t, W, θ, autodiff) .- f(phi(t, W, θ), p, t, W))
end

function inner_loss(phi::RODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector,
                    W::AbstractVector, θ, p) where {C, T, U}
    out = Array(phi(t, W, θ))
    arrt = Array(t)
    fs = reduce(hcat, [f(out[:, i], p, arrt[i], W[i]) for i in 1:size(out, 2)])
    dxdtguess = Array(rode_dfdx(phi, t, W, θ, autodiff))
    sum(abs2, dxdtguess .- fs) / length(t)
end
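
# generate_loss assembles the training objective over the GridTraining time
# grid and the sampled noise trajectories (one row of W per trajectory); with
# batch = true each trajectory is evaluated in a single batched network call.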

function generate_loss(strategy::GridTraining, phi, f, autodiff::Bool, tspan, W, p, batch)
    ts = tspan[1]:(strategy.dx):tspan[2]

    # sum(abs2,inner_loss(t,θ) for t in ts) but Zygote generators are broken
    println(typeof(W))

Review comment (Member), suggested change: remove the debug line println(typeof(W)).

    function loss(θ, _)
        if batch
            sum(abs2,
                [inner_loss(phi, f, autodiff, ts, W[j, :], θ, p) for j in 1:size(W)[1]])
        else
            sum(abs2,
                [sum(abs2,
                     [inner_loss(phi, f, autodiff, t, W[j, :][i], θ, p)
                      for (i, t) in enumerate(ts)]) for j in 1:size(W)[1]])
        end
    end
    optf = OptimizationFunction(loss, Optimization.AutoZygote())
end
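
# __solve samples `trajectories` realizations of the noise process, trains the
# network against the RODE residual on the time grid, and returns the
# optimization result together with the trained trial solution.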

function DiffEqBase.__solve(prob::DiffEqBase.AbstractRODEProblem,
                            alg::NNRODE,
                            args...;
                            dt = nothing,
                            trajectories = 100,
                            timeseries_errors = true,
                            save_everystep = true,
                            adaptive = false,
                            abstol = 1.0f-6,
                            reltol = 1.0f-3,
                            verbose = false,
                            saveat = nothing,
                            maxiters = nothing)
    u0 = prob.u0
    W = alg.W
    tspan = prob.tspan
    f = prob.f
    p = prob.p
@@ -42,75 +195,33 @@ function DiffEqBase.solve(prob::DiffEqBase.AbstractRODEProblem,
    chain = alg.chain
    opt = alg.opt
    autodiff = alg.autodiff
    Wg = alg.W

    # train points generation
    ts = tspan[1]:dt:tspan[2]
    init_params = alg.init_params

    phi, init_params = generate_phi_θ_rode(chain, t0, u0, init_params)

    strategy = isnothing(alg.strategy) ? GridTraining(dt) : alg.strategy
    batch = isnothing(alg.batch) ? false : alg.batch

    # Sample `trajectories` noise paths up front and stage them in a
    # Zygote.Buffer so the loss can index them without mutation errors.
    W_prob = NoiseProblem(Wg, tspan)
    W_en = EnsembleProblem(W_prob)
    W_sim = solve(W_en; dt = dt, trajectories = trajectories)
    W_bf = Zygote.Buffer(rand(length(W_sim), length(W_sim[1])))
    for (i, sol) in enumerate(W_sim)
        W_bf[i, :] = sol
    end
    optf = generate_loss(strategy, phi, f, autodiff::Bool, tspan, W_bf, p, batch)

    iteration = 0
    callback = function (p, l)
        iteration += 1
        verbose && println("Current loss is: $l, Iteration: $iteration")
        l < abstol
    end

    optprob = OptimizationProblem(optf, init_params)
    res = solve(optprob, opt; callback, maxiters, alg.kwargs...)

    res, (t, W) -> phi(t, W, res.u)
end #solve
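
A minimal usage sketch of the interface this PR adds. The drift function,
network sizes, optimizer, and keyword values below are illustrative
assumptions, not taken from the PR itself:

    using NeuralPDE, DiffEqNoiseProcess, Lux, OptimizationOptimisers

    # Scalar RODE du/dt = -u + sin(W(t)) driven by a Wiener process.
    f(u, p, t, W) = -u + sin(W)
    prob = RODEProblem(f, 1.0, (0.0, 1.0))

    # The network takes (t, W) as input, so the first layer has input size 2.
    chain = Lux.Chain(Lux.Dense(2, 16, tanh), Lux.Dense(16, 1))
    Wg = WienerProcess(0.0, 0.0)

    alg = NNRODE(chain, Wg, OptimizationOptimisers.Adam(0.01); batch = true)
    res, phi = solve(prob, alg; dt = 1 / 50, trajectories = 100, maxiters = 2000)
    phi(0.5, 0.1)  # trained trial solution at t = 0.5 for noise value W = 0.1

Note that __solve as written returns the optimization result and the trained
trial solution as a tuple rather than building a RODESolution.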