From a4fd6d8428775ca83b77dcad4142bd4e8fbdf55c Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 28 Nov 2024 16:05:01 +0530 Subject: [PATCH] feat: do not require `nlsolve_alg` for trivial `OverrideInit` --- src/initialization.jl | 46 +++++++++++++++++++++++++----------------- test/initialization.jl | 26 ++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 19 deletions(-) diff --git a/src/initialization.jl b/src/initialization.jl index 4e976e168..1179e606c 100644 --- a/src/initialization.jl +++ b/src/initialization.jl @@ -172,6 +172,9 @@ Keyword arguments: provided to the `OverrideInit` constructor takes priority over this keyword argument. If the former is `nothing`, this keyword argument will be used. If it is also not provided, an error will be thrown. + +In case the initialization problem is trivial, `nlsolve_alg`, `abstol` and `reltol` are +not required. """ function get_initial_values(prob, valp, f, alg::OverrideInit, iip::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, abstol = nothing, reltol = nothing, kwargs...) @@ -185,35 +188,40 @@ function get_initial_values(prob, valp, f, alg::OverrideInit, initdata::OverrideInitData = f.initialization_data initprob = initdata.initializeprob - nlsolve_alg = something(nlsolve_alg, alg.nlsolve, Some(nothing)) - if nlsolve_alg === nothing && state_values(initprob) !== nothing - throw(OverrideInitMissingAlgorithm()) - end - if initdata.update_initializeprob! !== nothing initdata.update_initializeprob!(initprob, valp) end - if alg.abstol !== nothing - _abstol = alg.abstol - elseif abstol !== nothing - _abstol = abstol - else - throw(OverrideInitNoTolerance(:abstol)) - end - if alg.reltol !== nothing - _reltol = alg.reltol - elseif reltol !== nothing - _reltol = reltol + if state_values(initprob) === nothing + nlsol = initprob + success = true else - throw(OverrideInitNoTolerance(:reltol)) + nlsolve_alg = something(nlsolve_alg, alg.nlsolve, Some(nothing)) + if nlsolve_alg === nothing && state_values(initprob) !== nothing + throw(OverrideInitMissingAlgorithm()) + end + if alg.abstol !== nothing + _abstol = alg.abstol + elseif abstol !== nothing + _abstol = abstol + else + throw(OverrideInitNoTolerance(:abstol)) + end + if alg.reltol !== nothing + _reltol = alg.reltol + elseif reltol !== nothing + _reltol = reltol + else + throw(OverrideInitNoTolerance(:reltol)) + end + nlsol = solve(initprob, nlsolve_alg; abstol = _abstol, reltol = _reltol) + success = SciMLBase.successful_retcode(nlsol) end - nlsol = solve(initprob, nlsolve_alg; abstol = _abstol, reltol = _reltol) u0 = initdata.initializeprobmap(nlsol) if initdata.initializeprobpmap !== nothing p = initdata.initializeprobpmap(valp, nlsol) end - return u0, p, SciMLBase.successful_retcode(nlsol) + return u0, p, success end diff --git a/test/initialization.jl b/test/initialization.jl index 1ea1da694..7d5dfb01d 100644 --- a/test/initialization.jl +++ b/test/initialization.jl @@ -244,4 +244,30 @@ end @test p ≈ 0.0 @test success end + + @testset "Trivial initialization" begin + initprob = NonlinearProblem(Returns(nothing), nothing, [1.0]) + update_initializeprob! = function (iprob, integ) + iprob.p[1] = integ.u[1] + end + initprobmap = function (nlsol) + u1 = parameter_values(nlsol)[1] + return [u1, u1] + end + initprobpmap = function (_, nlsol) + return 0.0 + end + initialization_data = SciMLBase.OverrideInitData( + initprob, update_initializeprob!, initprobmap, initprobpmap) + fn = ODEFunction(rhs2; initialization_data) + prob = ODEProblem(fn, [2.0, 0.0], (0.0, 1.0), 0.0) + integ = init(prob; initializealg = NoInit()) + + u0, p, success = SciMLBase.get_initial_values( + prob, integ, fn, SciMLBase.OverrideInit(), Val(false) + ) + @test u0 ≈ [2.0, 2.0] + @test p ≈ 0.0 + @test success + end end