Skip to content

Commit

Permalink
feat: do not require nlsolve_alg for trivial OverrideInit
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Dec 3, 2024
1 parent 86aa145 commit a4fd6d8
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 19 deletions.
46 changes: 27 additions & 19 deletions src/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,9 @@ Keyword arguments:
provided to the `OverrideInit` constructor takes priority over this keyword argument.
If the former is `nothing`, this keyword argument will be used. If it is also not provided,
an error will be thrown.
In case the initialization problem is trivial, `nlsolve_alg`, `abstol` and `reltol` are
not required.
"""
function get_initial_values(prob, valp, f, alg::OverrideInit,
iip::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, abstol = nothing, reltol = nothing, kwargs...)
Expand All @@ -185,35 +188,40 @@ function get_initial_values(prob, valp, f, alg::OverrideInit,
initdata::OverrideInitData = f.initialization_data
initprob = initdata.initializeprob

nlsolve_alg = something(nlsolve_alg, alg.nlsolve, Some(nothing))
if nlsolve_alg === nothing && state_values(initprob) !== nothing
throw(OverrideInitMissingAlgorithm())
end

if initdata.update_initializeprob! !== nothing
initdata.update_initializeprob!(initprob, valp)
end

if alg.abstol !== nothing
_abstol = alg.abstol
elseif abstol !== nothing
_abstol = abstol
else
throw(OverrideInitNoTolerance(:abstol))
end
if alg.reltol !== nothing
_reltol = alg.reltol
elseif reltol !== nothing
_reltol = reltol
if state_values(initprob) === nothing
nlsol = initprob
success = true
else
throw(OverrideInitNoTolerance(:reltol))
nlsolve_alg = something(nlsolve_alg, alg.nlsolve, Some(nothing))
if nlsolve_alg === nothing && state_values(initprob) !== nothing
throw(OverrideInitMissingAlgorithm())
end
if alg.abstol !== nothing
_abstol = alg.abstol
elseif abstol !== nothing
_abstol = abstol
else
throw(OverrideInitNoTolerance(:abstol))
end
if alg.reltol !== nothing
_reltol = alg.reltol
elseif reltol !== nothing
_reltol = reltol
else
throw(OverrideInitNoTolerance(:reltol))
end
nlsol = solve(initprob, nlsolve_alg; abstol = _abstol, reltol = _reltol)
success = SciMLBase.successful_retcode(nlsol)
end
nlsol = solve(initprob, nlsolve_alg; abstol = _abstol, reltol = _reltol)

u0 = initdata.initializeprobmap(nlsol)
if initdata.initializeprobpmap !== nothing
p = initdata.initializeprobpmap(valp, nlsol)
end

return u0, p, SciMLBase.successful_retcode(nlsol)
return u0, p, success
end
26 changes: 26 additions & 0 deletions test/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -244,4 +244,30 @@ end
@test p 0.0
@test success
end

@testset "Trivial initialization" begin
initprob = NonlinearProblem(Returns(nothing), nothing, [1.0])
update_initializeprob! = function (iprob, integ)
iprob.p[1] = integ.u[1]
end
initprobmap = function (nlsol)
u1 = parameter_values(nlsol)[1]
return [u1, u1]
end
initprobpmap = function (_, nlsol)
return 0.0
end
initialization_data = SciMLBase.OverrideInitData(
initprob, update_initializeprob!, initprobmap, initprobpmap)
fn = ODEFunction(rhs2; initialization_data)
prob = ODEProblem(fn, [2.0, 0.0], (0.0, 1.0), 0.0)
integ = init(prob; initializealg = NoInit())

u0, p, success = SciMLBase.get_initial_values(
prob, integ, fn, SciMLBase.OverrideInit(), Val(false)
)
@test u0 [2.0, 2.0]
@test p 0.0
@test success
end
end

0 comments on commit a4fd6d8

Please sign in to comment.