Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: enable OverrideInit to solve for du0 of DAEProblems #879

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions src/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

A collection of all the data required for `OverrideInit`.
"""
struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap}
struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap, IProbDu0Map}
"""
The `AbstractNonlinearProblem` to solve for initialization.
"""
Expand All @@ -29,12 +29,18 @@ struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap}
initialized will be returned as-is.
"""
initializeprobpmap::IProbPmap
"""
A function which takes the solution of `initializeprob` and returns the
`du0` vector of the original problem.
"""
initializeprob_du0map::IProbDu0Map

function OverrideInitData(initprob::I, update_initprob!::J, initprobmap::K,
initprobpmap::L) where {I, J, K, L}
initprobpmap::L, initprob_du0map::M = nothing) where {I, J, K, L, M}
@assert initprob isa
Union{SCCNonlinearProblem, NonlinearProblem, NonlinearLeastSquaresProblem}
return new{I, J, K, L}(initprob, update_initprob!, initprobmap, initprobpmap)
return new{I, J, K, L, M}(
initprob, update_initprob!, initprobmap, initprobpmap, initprob_du0map)
end
end

Expand Down Expand Up @@ -171,9 +177,12 @@ Keyword arguments:
provided to the `OverrideInit` constructor takes priority over this keyword argument.
If the former is `nothing`, this keyword argument will be used. If it is also not provided,
an error will be thrown.
- `return_du0`: Whether to use `initializeprob_du0map` (if present) and return
`du0, u0, p, success`.
"""
function get_initial_values(prob, valp, f, alg::OverrideInit,
iip::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, abstol = nothing, reltol = nothing, kwargs...)
iip::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, abstol = nothing,
reltol = nothing, return_du0 = false, kwargs...)
u0 = state_values(valp)
p = parameter_values(valp)

Expand Down Expand Up @@ -214,5 +223,10 @@ function get_initial_values(prob, valp, f, alg::OverrideInit,
p = initdata.initializeprobpmap(valp, nlsol)
end

if return_du0
du0 = initdata.initializeprob_du0map === nothing ? nothing : initdata.initializeprob_du0map(nlsol)
return du0, u0, p, SciMLBase.successful_retcode(nlsol)
end

return u0, p, SciMLBase.successful_retcode(nlsol)
end
62 changes: 62 additions & 0 deletions test/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -229,4 +229,66 @@ end
@test p ≈ 0.0
@test success
end

@testset "DAEProblem" begin
function daerhs(du, u, p, t)
return [u[1] * t + p, u[1]^2 - u[2]^2]
end
# unknowns are u[2], p, D(u[1]), D(u[2]). Parameters are u[1], t
initprob = NonlinearProblem([1.0, 1.0, 1.0, 1.0], [1.0, 0.0]) do x, _p
u2, p, du1, du2 = x
u1, t = _p
return [u1^3 - u2^3, p^2 - 2p + 1, du1 - u1 * t - p, 2u1 * du1 - 2u2 * du2]
end

update_initializeprob! = function (iprob, integ)
iprob.p[1] = integ.u[1]
iprob.p[2] = integ.t
end
initprobmap = function (nlsol)
return [parameter_values(nlsol)[1], nlsol.u[1]]
end
initprobpmap = function (_, nlsol)
return nlsol.u[2]
end
initprob_du0map = function (nlsol)
return nlsol.u[3:4]
end
initialization_data = SciMLBase.OverrideInitData(
initprob, update_initializeprob!, initprobmap, initprobpmap, initprob_du0map)
fn = DAEFunction(daerhs; initialization_data)
prob = DAEProblem(fn, [0.0, 0.0], [2.0, 0.0], (0.0, 1.0), 0.0)
integ = init(prob, DImplicitEuler(); initializealg = NoInit())

initialization_data2 = SciMLBase.OverrideInitData(
initprob, update_initializeprob!, initprobmap, initprobpmap)
fn2 = DAEFunction(daerhs; initialization_data = initialization_data2)
prob2 = DAEProblem(fn2, [0.0, 0.0], [2.0, 0.0], (0.0, 1.0), 0.0)
integ2 = init(prob2, DImplicitEuler(); initializealg = NoInit())

nlsolve_alg = FastShortcutNonlinearPolyalg()
@testset "Doesn't return `du0` by default" begin
@test length(SciMLBase.get_initial_values(
prob, integ, fn, SciMLBase.OverrideInit(),
Val(false); nlsolve_alg, abstol, reltol)) == 3
end
@testset "`du0 === nothing` if missing `du0map`" begin
du0, u0, p, success = SciMLBase.get_initial_values(
prob2, integ2, fn2, SciMLBase.OverrideInit(), Val(false);
nlsolve_alg, abstol, reltol, return_du0 = true)
@test du0 === nothing
@test u0 ≈ [2.0, 2.0]
@test p ≈ 1.0
@test success
end
@testset "With `return_du0 = true`" begin
du0, u0, p, success = SciMLBase.get_initial_values(
prob, integ, fn, SciMLBase.OverrideInit(), Val(false);
nlsolve_alg, abstol, reltol, return_du0 = true)
@test du0 ≈ [1.0, 1.0]
@test u0 ≈ [2.0, 2.0]
@test p ≈ 1.0
@test success
end
end
end
Loading