Skip to content

Commit

Permalink
feat: do not require nlsolve_alg for trivial OverrideInit
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Nov 28, 2024
1 parent a5ee8e9 commit a6a4a1f
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 14 deletions.
37 changes: 23 additions & 14 deletions src/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,9 @@ Keyword arguments:
provided to the `OverrideInit` constructor takes priority over this keyword argument.
If the former is `nothing`, this keyword argument will be used. If it is also not provided,
an error will be thrown.
In case the initialization problem is trivial, `nlsolve_alg`, `abstol` and `reltol` are
not required.
"""
function get_initial_values(prob, valp, f, alg::OverrideInit,
iip::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, abstol = nothing, reltol = nothing, kwargs...)
Expand All @@ -193,26 +196,32 @@ function get_initial_values(prob, valp, f, alg::OverrideInit,
initdata.update_initializeprob!(initprob, valp)
end

if alg.abstol !== nothing
_abstol = alg.abstol
elseif abstol !== nothing
_abstol = abstol
else
throw(OverrideInitNoTolerance(:abstol))
end
if alg.reltol !== nothing
_reltol = alg.reltol
elseif reltol !== nothing
_reltol = reltol
if state_values(initprob) === nothing
nlsol = initprob
success = true
else
throw(OverrideInitNoTolerance(:reltol))
if alg.abstol !== nothing
_abstol = alg.abstol
elseif abstol !== nothing
_abstol = abstol
else
throw(OverrideInitNoTolerance(:abstol))
end
if alg.reltol !== nothing
_reltol = alg.reltol
elseif reltol !== nothing
_reltol = reltol
else
throw(OverrideInitNoTolerance(:reltol))
end
nlsol = solve(initprob, nlsolve_alg; abstol = _abstol, reltol = _reltol)
success = SciMLBase.successful_retcode(nlsol)
end
nlsol = solve(initprob, nlsolve_alg; abstol = _abstol, reltol = _reltol)

u0 = initdata.initializeprobmap(nlsol)
if initdata.initializeprobpmap !== nothing
p = initdata.initializeprobpmap(valp, nlsol)
end

return u0, p, SciMLBase.successful_retcode(nlsol)
return u0, p, success
end
26 changes: 26 additions & 0 deletions test/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -229,4 +229,30 @@ end
@test p 0.0
@test success
end

@testset "Trivial initialization" begin
initprob = NonlinearProblem(Returns(nothing), nothing, [1.0])
update_initializeprob! = function (iprob, integ)
iprob.p[1] = integ.u[1]
end
initprobmap = function (nlsol)
u1 = parameter_values(nlsol)[1]
return [u1, u1]
end
initprobpmap = function (_, nlsol)
return 0.0
end
initialization_data = SciMLBase.OverrideInitData(
initprob, update_initializeprob!, initprobmap, initprobpmap)
fn = ODEFunction(rhs2; initialization_data)
prob = ODEProblem(fn, [2.0, 0.0], (0.0, 1.0), 0.0)
integ = init(prob; initializealg = NoInit())

u0, p, success = SciMLBase.get_initial_values(
prob, integ, fn, SciMLBase.OverrideInit(), Val(false)
)
@test u0 [2.0, 2.0]
@test p 0.0
@test success
end
end

0 comments on commit a6a4a1f

Please sign in to comment.