Skip to content

Commit

Permalink
test: test lazy initialization in remake for supported problem types
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Dec 9, 2024
1 parent 58f6852 commit fbcc39e
Showing 1 changed file with 40 additions and 6 deletions.
46 changes: 40 additions & 6 deletions test/downstream/modelingtoolkit_remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -337,11 +337,45 @@ end
@test sccprob4.p !== sccprob4.probs[2].p
end

# TODO: Rewrite this test when MTK build initialization for everything
@testset "Lazy initialization" begin
@variables x(t) [guess = 1.0] y(t) [guess = 1.0]
@parameters p=missing [guess = 1.0]
@mtkbuild sys = ODESystem([D(x) ~ x, x + y ~ p], t)
prob = ODEProblem(sys, [x => 1.0, y => 1.0], (0.0, 1.0))
prob2 = remake(prob; u0 = [x => 2.0])
@test prob2.ps[p] 3.0
@variables _x(..) [guess = 1.0] y(t) [guess = 1.0]
@parameters p=1.0 [guess = 1.0]
@brownian a
x = _x(t)

initprob = NonlinearProblem(nothing) do args...
return 0.0
end
iprobmap = (_...) -> [1.0, 1.0]
iprobpmap = function (orig, sol)
ps = parameter_values(orig)
setp(orig, p)(ps, 3.0)
return ps
end
initdata = SciMLBase.OverrideInitData(initprob, nothing, iprobmap, iprobpmap)
@test SciMLBase.is_trivial_initialization(initdata)

@testset "$Problem" for (SystemT, rhss, Problem, Func) in [
(ODESystem, 0.0, ODEProblem, ODEFunction),
(System, a, SDEProblem, SDEFunction),
(ODESystem, _x(t - 0.1), DDEProblem, DDEFunction),
(System, _x(t - 0.1) + a, SDDEProblem, SDDEFunction),
(NonlinearSystem, y + 2, NonlinearProblem, NonlinearFunction),
(NonlinearSystem, y + 2, NonlinearLeastSquaresProblem, NonlinearFunction)
]
is_nlsolve = SystemT == NonlinearSystem
D = is_nlsolve ? (v) -> v^3 : Differential(t)
sys_args = is_nlsolve ? () : (t,)
prob_args = is_nlsolve ? () : ((0.0, 1.0),)

@mtkbuild sys = SystemT([D(x) ~ x + rhss, x + y ~ p], sys_args...)
prob = Problem(sys, [x => 1.0, y => 1.0], prob_args...)
func_args = isdefined(prob.f, :g) ? (prob.f.g,) : ()
func = Func{true, SciMLBase.FullSpecialize}(
prob.f.f, func_args...; initialization_data = initdata, sys = prob.f.sys)
prob2 = remake(prob; f = func)
@test SciMLBase.is_trivial_initialization(prob2)
@test prob2.ps[p] 3.0
end
end

0 comments on commit fbcc39e

Please sign in to comment.