Skip to content

Commit

Permalink
refactor: generalize _initialize_dae! to use SciMLBase implementations
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Nov 19, 2024
1 parent 34a49c1 commit a5097b8
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 107 deletions.
2 changes: 1 addition & 1 deletion lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ using DiffEqBase: check_error!, @def, _vec, _reshape

using FastBroadcast: @.., True, False

using SciMLBase: NoInit, CheckInit, _unwrap_val
using SciMLBase: NoInit, CheckInit, OverrideInit, AbstractDEProblem, _unwrap_val

import SciMLBase: alg_order

Expand Down
123 changes: 17 additions & 106 deletions lib/OrdinaryDiffEqCore/src/initialize_dae.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,6 @@ function BrownFullBasicInit(; abstol = 1e-10, nlsolve = nothing)
end
BrownFullBasicInit(abstol) = BrownFullBasicInit(; abstol = abstol, nlsolve = nothing)

struct OverrideInit{T, F} <: DiffEqBase.DAEInitializationAlgorithm
abstol::T
nlsolve::F
end

function OverrideInit(; abstol = 1e-10, nlsolve = nothing)
OverrideInit(abstol, nlsolve)
end
OverrideInit(abstol) = OverrideInit(; abstol = abstol, nlsolve = nothing)

## Notes

#=
Expand Down Expand Up @@ -143,19 +133,15 @@ end

## NoInit

function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem},
function _initialize_dae!(integrator, prob::AbstractDEProblem,
alg::NoInit, x::Union{Val{true}, Val{false}})
end

## OverrideInit

function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem},
function _initialize_dae!(integrator, prob::AbstractDEProblem,
alg::OverrideInit, isinplace::Union{Val{true}, Val{false}})
initializeprob = prob.f.initializeprob

if SciMLBase.has_update_initializeprob!(prob.f)
prob.f.update_initializeprob!(initializeprob, prob)
end
initializeprob = prob.f.initialization_data.initializeprob

# If it doesn't have autodiff, assume it comes from symbolic system like ModelingToolkit
# Since then it's the case of not a DAE but has initializeprob
Expand All @@ -168,105 +154,30 @@ function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem},
true
end

alg = default_nlsolve(alg.nlsolve, isinplace, initializeprob.u0, initializeprob, isAD)
nlsol = solve(initializeprob, alg, abstol = integrator.opts.abstol, reltol = integrator.opts.reltol)
nlsolve_alg = default_nlsolve(alg.nlsolve, isinplace, initializeprob.u0, initializeprob, isAD)

u0, p, success = SciMLBase.get_initial_values(prob, prob.f, integrator, alg, isinplace; nlsolve_alg, abstol = integrator.opts.abstol, reltol = integrator.opts.reltol)

if isinplace === Val{true}()
integrator.u .= prob.f.initializeprobmap(nlsol)
integrator.u .= u0
elseif isinplace === Val{false}()
integrator.u = prob.f.initializeprobmap(nlsol)
integrator.u = u0
else
error("Unreachable reached. Report this error.")
end
if SciMLBase.has_initializeprobpmap(prob.f)
integrator.p = prob.f.initializeprobpmap(prob, nlsol)
sol = integrator.sol
@reset sol.prob.p = integrator.p
integrator.sol = sol
end
integrator.p = p
sol = integrator.sol
@reset sol.prob.p = integrator.p
integrator.sol = sol

if nlsol.retcode != ReturnCode.Success
if !success
integrator.sol = SciMLBase.solution_new_retcode(integrator.sol,
ReturnCode.InitialFailure)
end
end

## CheckInit
struct CheckInitFailureError <: Exception
normresid::Any
abstol::Any
end

function Base.showerror(io::IO, e::CheckInitFailureError)
print(io,
"CheckInit specified but initialization not satisifed. normresid = $(e.normresid) > abstol = $(e.abstol)")
end

function _initialize_dae!(integrator, prob::ODEProblem, alg::CheckInit,
isinplace::Val{true})
@unpack p, t, f = integrator
M = integrator.f.mass_matrix
tmp = first(get_tmp_cache(integrator))
u0 = integrator.u

algebraic_vars = [all(iszero, x) for x in eachcol(M)]
algebraic_eqs = [all(iszero, x) for x in eachrow(M)]
(iszero(algebraic_vars) || iszero(algebraic_eqs)) && return
update_coefficients!(M, u0, p, t)
f(tmp, u0, p, t)
tmp .= ArrayInterface.restructure(tmp, algebraic_eqs .* _vec(tmp))

normresid = integrator.opts.internalnorm(tmp, t)
if normresid > integrator.opts.abstol
throw(CheckInitFailureError(normresid, integrator.opts.abstol))
end
end

function _initialize_dae!(integrator, prob::ODEProblem, alg::CheckInit,
isinplace::Val{false})
@unpack p, t, f = integrator
u0 = integrator.u
M = integrator.f.mass_matrix

algebraic_vars = [all(iszero, x) for x in eachcol(M)]
algebraic_eqs = [all(iszero, x) for x in eachrow(M)]
(iszero(algebraic_vars) || iszero(algebraic_eqs)) && return
update_coefficients!(M, u0, p, t)
du = f(u0, p, t)
resid = _vec(du)[algebraic_eqs]

normresid = integrator.opts.internalnorm(resid, t)
if normresid > integrator.opts.abstol
throw(CheckInitFailureError(normresid, integrator.opts.abstol))
end
end

function _initialize_dae!(integrator, prob::DAEProblem,
alg::CheckInit, isinplace::Val{true})
@unpack p, t, f = integrator
u0 = integrator.u
resid = get_tmp_cache(integrator)[2]

f(resid, integrator.du, u0, p, t)
normresid = integrator.opts.internalnorm(resid, t)
if normresid > integrator.opts.abstol
throw(CheckInitFailureError(normresid, integrator.opts.abstol))
end
end

function _initialize_dae!(integrator, prob::DAEProblem,
alg::CheckInit, isinplace::Val{false})
@unpack p, t, f = integrator
u0 = integrator.u

nlequation_oop = u -> begin
f((u - u0) / dt, u, p, t)
end

nlequation = (u, _) -> nlequation_oop(u)

resid = f(integrator.du, u0, p, t)
normresid = integrator.opts.internalnorm(resid, t)
if normresid > integrator.opts.abstol
throw(CheckInitFailureError(normresid, integrator.opts.abstol))
end
function _initialize_dae!(integrator, prob::AbstractDEProblem, alg::CheckInit,
isinplace::Union{Val{true}, Val{false}})
SciMLBase.get_initial_values(prob, integrator, prob.f, alg, isinplace; abstol = integrator.opts.abstol)
end

0 comments on commit a5097b8

Please sign in to comment.